[WIP] Weekly update of repo (#8390)

* [WIP] Fix compilation after nd4j changes (#37)

* Fix compilation.

* Some tests fixed

* Disable tests temporarily.

* Restored test

* Tests restored.

* Test restored.

* [WIP] perf tests (#40)

* special maxpool test

Signed-off-by: raver119 <raver119@gmail.com>

* special maxpool test

Signed-off-by: raver119 <raver119@gmail.com>

* Shyrma bnorm bp (#41)

Batchnorm backprop mkldnn

* Add SameDiff memory reuse memory manager (array cache) (#39)

* Attention op comments

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ArrayCacheMemoryMgr - first pass

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak array cache for use with SameDiff identity arrays

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ArrayCacheMemoryMgr javadoc and properly get max memory

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* LRU cache policy + add tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Resize arrays internally if required for ArrayCacheMemoryMgr

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test improvement

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
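
The ArrayCacheMemoryMgr described in the commits above is a Java-side SameDiff memory manager; the sketch below only illustrates the reuse-by-size idea with LRU-style eviction and a byte budget, written in C++ for consistency with the rest of this page. All names and signatures here are hypothetical, not the actual nd4j API.

    // Illustrative sketch only: the real ArrayCacheMemoryMgr is a Java class in SameDiff.
    // Shows reuse-by-size plus LRU eviction under a byte budget; all names are hypothetical.
    #include <cstddef>
    #include <list>
    #include <map>
    #include <vector>

    class ArrayCacheSketch {
    public:
        explicit ArrayCacheSketch(size_t maxBytes) : _maxBytes(maxBytes) {}

        // Return a cached buffer of at least `bytes` if one exists, otherwise allocate a new one.
        std::vector<char> acquire(size_t bytes) {
            auto it = _bySize.lower_bound(bytes);
            if (it != _bySize.end()) {
                std::vector<char> buf = std::move(it->second);
                _cachedBytes -= buf.size();
                _bySize.erase(it);
                return buf;                      // reused allocation (caller may resize it)
            }
            return std::vector<char>(bytes);     // cache miss: fresh allocation
        }

        // Hand a no-longer-needed buffer back; evict oldest entries while over budget.
        void release(std::vector<char> buf) {
            const size_t sz = buf.size();
            _cachedBytes += sz;
            _lru.push_back(sz);
            _bySize.emplace(sz, std::move(buf));
            while (_cachedBytes > _maxBytes && !_lru.empty()) {
                const size_t victim = _lru.front();   // oldest released size class
                _lru.pop_front();
                auto v = _bySize.find(victim);
                if (v != _bySize.end()) {
                    _cachedBytes -= v->second.size();
                    _bySize.erase(v);
                }
            }
        }

    private:
        size_t _maxBytes;
        size_t _cachedBytes = 0;
        std::multimap<size_t, std::vector<char>> _bySize;  // size -> cached buffers
        std::list<size_t> _lru;                            // release order, oldest first
    };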

* SameDiff op runtime benchmarking listener (#42)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* INLINE_LOOPS for windows

Signed-off-by: raver119 <raver119@gmail.com>

* [WIP] ThreadPool (#8)

This PR removes OpenMP use in 95% of cases
Branch: master
Author: raver119, 2019-11-13 17:15:18 +03:00, committed by GitHub
Parent: c0f91e5c3c
Commit: 6de00bf75f
293 changed files with 9700 additions and 12064 deletions
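
The change that recurs throughout the C++ hunks below is the replacement of OpenMP pragmas (and omp_set_nested calls) with loop bodies wrapped in PRAGMA_THREADS_FOR and dispatched through samediff::Threads::parallel_for / parallel_tad. The following is a minimal, self-contained sketch of that dispatch pattern only; it uses plain std::thread and illustrative names (parallel_for_sketch, LoopBody), not the actual libnd4j ThreadPool API.

    // Minimal sketch of the "lambda loop body + parallel_for dispatcher" pattern used below.
    // The real implementation lives in samediff::Threads / PRAGMA_THREADS_FOR; this stand-in
    // uses std::thread and assumed signatures, not the actual libnd4j API.
    #include <cstdint>
    #include <functional>
    #include <thread>
    #include <vector>

    using LoopBody = std::function<void(int thread_id, int64_t start, int64_t stop, int64_t increment)>;

    static void parallel_for_sketch(const LoopBody &func, int64_t start, int64_t stop,
                                    int64_t increment = 1, int numThreads = 4) {
        const int64_t span = (stop - start) / numThreads;   // contiguous chunk per thread
        std::vector<std::thread> workers;
        for (int t = 0; t < numThreads; t++) {
            const int64_t s = start + t * span;
            const int64_t e = (t == numThreads - 1) ? stop : s + span;
            workers.emplace_back([&func, t, s, e, increment] { func(t, s, e, increment); });
        }
        for (auto &w : workers)
            w.join();
    }

    int main() {
        std::vector<float> x(1024, 1.0f), z(1024, 0.0f);
        // Analogous to: auto func = PRAGMA_THREADS_FOR { for (auto i = start; i < stop; i += increment) ... };
        auto func = [&](int /*thread_id*/, int64_t start, int64_t stop, int64_t increment) {
            for (auto i = start; i < stop; i += increment)
                z[i] = x[i] * 2.0f;
        };
        parallel_for_sketch(func, 0, (int64_t) z.size());
        return 0;
    }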

View File

@@ -98,6 +98,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         SDVariable diff = sd.f().squaredDifference(a1, label);
         SDVariable lossMse = diff.mean();
+        lossMse.markAsLoss();
         IUpdater updater;
         double lr;

View File

@@ -35,6 +35,7 @@ import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.factory.NDArrayFactory;
@@ -482,23 +483,12 @@ public class ConvolutionUtils {
                 return reshape5dTo2d(format, mask, workspaceMgr, type);
             } else {
                 //Need to broadcast first
-                IntArrayList broadcastDims = new IntArrayList();
-                for(int i=0; i<mask.rank(); i++ ){
-                    if(mask.size(i) == label.size(i)){
-                        if((format == Convolution3D.DataFormat.NCDHW && i == 1) || (format == Convolution3D.DataFormat.NDHWC && i == 4)){
-                            //Skip channels dimension
-                            continue;
-                        }
-                        broadcastDims.add(i);
-                    }
-                }
                 long[] lShape = label.shape().clone();
                 int channelIdx = format == Convolution3D.DataFormat.NCDHW ? 1 : 4;
                 lShape[channelIdx] = mask.size(channelIdx); //Keep existing channel size
                 INDArray bMask = workspaceMgr.createUninitialized(type, mask.dataType(), lShape, 'c');
-                int[] bcDims = broadcastDims.toIntArray();
-                Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, bcDims));
+                Nd4j.exec(new Assign(new INDArray[]{bMask, mask}, new INDArray[]{bMask}));
                 return reshape5dTo2d(format, bMask, workspaceMgr, type);
             }
         }

View File

@@ -16,17 +16,20 @@ endif()
 # -fsanitize=address
 # -fsanitize=leak
-if (APPLE)
-    set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
-    set(CMAKE_CXX_FLAGS_DEBUG  " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
+if (ANDROID_BUILD)
+    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true")
+    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else")
+elseif (APPLE)
+    set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
+    set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true")
 elseif(WIN32)
     set(X86_BUILD true)
-    if (NOT CUDA_BLAS)
-        set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
-        set(CMAKE_CXX_FLAGS_DEBUG  " -g -fPIC -std=c++11 -fmax-errors=2")
-    else()
+    if (CUDA_BLAS)
+        set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804")
         set(CMAKE_CXX_FLAGS_DEBUG  " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
+    else()
+        set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
+        set(CMAKE_CXX_FLAGS_DEBUG  " -g -O2 -fPIC -std=c++11 -fmax-errors=2")
     endif()
 else()
     set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
@@ -75,6 +78,9 @@ if(NOT CUDA_BLAS)
         message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
         add_definitions(-D__EXTERNAL_BLAS__=true)
+    elseif(WIN32)
+        message("BLAS not found, using downloaded OpenBLAS instead")
+        add_definitions(-D__EXTERNAL_BLAS__=true)
     endif()
 else()
     # if we have externally provided OPENBLAS_PATH - let's use it

View File

@@ -5,7 +5,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 ExternalProject_Add(mkldnn
     GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
-    GIT_TAG v1.0.2
+    GIT_TAG v1.0.4
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
     CONFIGURE_COMMAND ""

View File

@@ -30,8 +30,8 @@ if(APPLE)
 endif()
 if (APPLE_BUILD)
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAPPLE_BUILD=true")
-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
 endif()
 if (ANDROID_BUILD)
@@ -92,11 +92,13 @@ ELSE()
     IF(${EXTENSION} MATCHES "avx512")
         message("Building AVX512 binary...")
         # we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
+        message("Current CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
     endif()
-    set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
+    if (NOT WIN32)
+        # we don't want this definition for msvc
+        set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
+    endif()
 ENDIF()
 if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
@@ -109,7 +111,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
     # using Visual Studio C++
-    set( CMAKE_CXX_FLAGS  "${CMAKE_CXX_FLAGS} /EHsc /w")
+    set( CMAKE_CXX_FLAGS  "${CMAKE_CXX_FLAGS} /EHsc ${ARCH_TUNE}")
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
     # using GCC
     SET( CMAKE_CXX_FLAGS  "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
@@ -283,8 +285,8 @@ if(CUDA_BLAS)
         if(WIN32)
             message("CUDA on Windows: enabling /EHsc")
-            SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj")
-            SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj")
+            SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
+            SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14")
         endif()
@@ -322,7 +324,7 @@ elseif(CPU_BLAS)
     endif()
     if (X86_BUILD)
-        #we disable platform optimizations for certains files
+        # we disable platform optimizations for certains files for linux/macos
         set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
         set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
     endif()
@@ -342,7 +344,16 @@ elseif(CPU_BLAS)
         add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
     endif()
+    #if(WIN32)
+    #    message("CPU on Windows: enabling /EHsc")
+    #    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
+    #    SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14")
+    #endif()
     # we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
+    if (NOT BLAS_LIBRARIES)
+        set(BLAS_LIBRARIES "")
+    endif()
     target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
     if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")

View File

@@ -24,6 +24,8 @@
 #include <string>
 #include "Environment.h"
 #include <helpers/StringUtils.h>
+#include <thread>
+#include <helpers/logger.h>
 #ifdef _OPENMP
@@ -49,6 +51,7 @@ namespace nd4j {
         _precBoost.store(false);
         _leaks.store(false);
         _dataType.store(nd4j::DataType::FLOAT32);
+        _maxThreads = std::thread::hardware_concurrency();
 #ifndef ANDROID
         const char* omp_threads = std::getenv("OMP_NUM_THREADS");
@@ -86,9 +89,7 @@ namespace nd4j {
         cudaSetDevice(0);
         delete[] devProperties;
 #else
-#ifdef _OPENMP
-        omp_set_nested(1);
-#endif
 #endif
     }

View File

@@ -26,6 +26,7 @@
 #include <indexing/IndicesList.h>
 #include <graph/Intervals.h>
 #include <array/DataType.h>
+#include <array/DataTypeUtils.h>
 #include <stdint.h>
 #include <array/ArrayOptions.h>
 #include <array/ArrayType.h>
@@ -1678,7 +1679,6 @@ namespace nd4j {
     //////////////////////////////////////////////////////////////////////////
     size_t NDArray::sizeOfT() const {
         return DataTypeUtils::sizeOfElement(_dataType);
     }

View File

@@ -2478,7 +2478,6 @@ double NDArray::getTrace() const {
     double sum  = 0.;
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
     for(int i = 0; i < minDim; ++i)
         sum += e<double>(i * offset);
@@ -3275,7 +3274,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
     // regular numeric types
     NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
-    ExtraArguments extras({eps});
+    ExtraArguments extras({0.0, 0.0, eps});
     NDArray::prepareSpecialUse({&tmp}, {this, other});
     NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
@@ -3288,7 +3287,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
     synchronize("NDArray::equalsTo");
-    if (tmp.e<int>(0) > 0)
+    if (tmp.e<Nd4jLong>(0) != 0)
         return false;
     return true;

View File

@@ -24,10 +24,10 @@
 #include <types/types.h>
 #include <dll.h>
-#include <loops/aggregates.h>
 #include <ops/specials.h>
 #include <ops/specials_sparse.h>
 #include <execution/LaunchContext.h>
+#include <array/ArrayOptions.h>
 /**
  * Native op executioner:
@@ -624,10 +624,6 @@
                              void *vrealArguments,
                              int numRealArguments) {
-    auto arguments = reinterpret_cast<X **>(varguments);
-    auto realArguments = reinterpret_cast<X *>(vrealArguments);
-
-    functions::aggregate::AggregatedFunction<X>::exec(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
 }

View File

@@ -55,7 +55,6 @@
 #define ND4J_EXPORT
 #endif
 #include <dll.h>
-#include <helpers/BlasHelper.h>
 /*
 int tad_threshold = 1;
@@ -1430,7 +1429,11 @@ static const char* getNpyArrayNameFromMap(void *map, int index){
     for(; it != end; ++it, ++cnt){
         if (cnt == index){
             // FIXME: @fariz, this is a leak!
+#ifdef _MSC_VER
+            return const_cast<const char *>(_strdup(it->first.c_str()));
+#else
             return const_cast<const char *>(strdup(it->first.c_str()));
+#endif
         }
     }
     throw std::runtime_error("No array at index.");

View File

@@ -98,24 +98,27 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
     const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo());
-    std::vector<Nd4jLong> coords(zRank);
-
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
-    for (Nd4jLong i = 0; i < zLen; ++i) {
-
-        shape::index2coords(i, target->getShapeInfo(), coords.data());
-        const auto zOffset = shape::getOffset(target->getShapeInfo(), coords.data());
-
-        // if( (row + upper < col) || (row + lower > col) )
-        if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
-            z[zOffset] = value;
-        else if(this != target) {      // when this and target are different arrays
-            if(xRank != zRank)
-                coords[0] = coords[1];
-
-            const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords.data());
-            z[zOffset] = x[xOffset];
-        }
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        Nd4jLong coords[MAX_RANK];
+        for (auto i = start; i < stop; i += increment) {
+            shape::index2coords(i, target->getShapeInfo(), coords);
+            const auto zOffset = shape::getOffset(target->getShapeInfo(), coords);
+
+            // if( (row + upper < col) || (row + lower > col) )
+            if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
+                z[zOffset] = value;
+            else if (this != target) {      // when this and target are different arrays
+                if (xRank != zRank)
+                    coords[0] = coords[1];
+
+                const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords);
+                z[zOffset] = x[xOffset];
+            }
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, zLen);
 }
 BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
@@ -140,7 +143,7 @@ void NDArray::setIdentity() {
         minDim = shape[i];
     float v = 1.0f;
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
     for(int i = 0; i < minDim; ++i)
         templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
 }
@@ -151,12 +154,15 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
     auto x = reinterpret_cast<T *>(xBuffer);
     auto y = reinterpret_cast<T *>(yBuffer);
-    PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(static))
-    for (Nd4jLong i = 0; i < length; ++i) {
-        auto temp = x[i];
-        x[i] = y[i];
-        y[i] = temp;
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment) {
+            auto temp = x[i];
+            x[i] = y[i];
+            y[i] = temp;
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, length);
 }
 BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
@@ -262,21 +268,26 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
     auto xType = this->dataType();
     if(result.ordering() == 'c') {           //  ews == 1 always here
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for(Nd4jLong i = 0;  i < resultLen; ++i) {
-            auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
-            BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
-        }
+        auto func = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; i += increment) {
+                auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
+                BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
+            }
+        };
+
+        samediff::Threads::parallel_for(func, 0, resultLen);
     }
     else {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for(Nd4jLong i=0;  i<resultLen; ++i) {
-            auto xOffset = result.getOffset(i);
-            auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
-            BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
-        }
+        auto func = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; i += increment) {
+                auto xOffset = result.getOffset(i);
+                auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
+                BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
+            }
+        };
+
+        samediff::Threads::parallel_for(func, 0, resultLen);
     }
     result.tickWriteHost();
     return result;
@@ -337,14 +348,7 @@ void NDArray::tile(NDArray& target) const {
     // looping through _buffer goes automatically by means of getSubArrayIndex applying
     const auto ews = target.ews();
     const auto targetLen = target.lengthOf();
-    if(target.ordering() == 'c' && ews == 1) {           //  ews == 1 always here
-        for (Nd4jLong i = 0; i < targetLen; ++i) {
-            auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
-            BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
-        }
-    }
-    else if(target.ordering() == 'c' && ews > 1) {
+    if(target.ordering() == 'c' && ews >= 1) {
         for(Nd4jLong i=0;  i<targetLen; ++i) {
             auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
@@ -373,30 +377,30 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
     const int zLen = output.lengthOf(); // xLen <= zLen
     const int repSize = repeats.size();
-    std::vector<Nd4jLong> coords(rank);
     // loop through input array
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords))
-    for (Nd4jLong i = 0; i < zLen; ++i) {
-
-        shape::index2coords(i, output.getShapeInfo(), coords.data());
-        const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
-
-        if(repSize > 1) {
-            for (uint j = 0; j < repSize; ++j) {
-                coords[axis] -= repeats[j];
-                if (coords[axis] < 0) {
-                    coords[axis] = j;
-                    break;
-                }
-            }
-        }
-        else
-            coords[axis] /= repeats[0];
-
-        z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        Nd4jLong coords[MAX_RANK];
+        for (auto i = start; i < stop; i += increment) {
+            shape::index2coords(i, output.getShapeInfo(), coords);
+            const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);
+
+            if (repSize > 1) {
+                for (uint j = 0; j < repSize; ++j) {
+                    coords[axis] -= repeats[j];
+                    if (coords[axis] < 0) {
+                        coords[axis] = j;
+                        break;
+                    }
+                }
+            } else
+                coords[axis] /= repeats[0];
+
+            z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords)];
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, zLen);
 }
 //////////////////////////////////////////////////////////////////////////

View File

@@ -32,33 +32,40 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
     if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for (Nd4jLong e = 0; e < _length; e++)
-            z[e] = func(f[e], s[e], t[e]);
+        auto loop = PRAGMA_THREADS_FOR {
+            for (auto e = start; e < stop; e += increment)
+                z[e] = func(f[e], s[e], t[e]);
+        };
+
+        samediff::Threads::parallel_for(loop, 0, _length);
     } else {
         if (f == z) {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto tOffset = this->getOffset(e);
-                auto uOffset = second->getOffset(e);
-                auto vOffset = third->getOffset(e);
-                f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto tOffset = this->getOffset(e);
+                    auto uOffset = second->getOffset(e);
+                    auto vOffset = third->getOffset(e);
+
+                    f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
         } else {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto tOffset = this->getOffset(e);
-                auto uOffset = second->getOffset(e);
-                auto vOffset = third->getOffset(e);
-                auto zOffset = target->getOffset(e);
-                z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto tOffset = this->getOffset(e);
+                    auto uOffset = second->getOffset(e);
+                    auto vOffset = third->getOffset(e);
+                    auto zOffset = target->getOffset(e);
+
+                    z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
         }
     }
 }
@@ -103,31 +110,38 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
     if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for (Nd4jLong e = 0; e < _length; e++)
-            z[e] = func(f[e], s[e]);
+        auto loop = PRAGMA_THREADS_FOR {
+            for (auto e = start; e < stop; e += increment)
+                z[e] = func(f[e], s[e]);
+        };
+
+        samediff::Threads::parallel_for(loop, 0, _length);
     } else {
         if (f == z) {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto yOffset = other->getOffset(e);
-                f[xOffset] = func(f[xOffset], s[yOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto yOffset = other->getOffset(e);
+
+                    f[xOffset] = func(f[xOffset], s[yOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        } else {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto yOffset = other->getOffset(e);
-                auto zOffset = target->getOffset(e);
-                z[zOffset] = func(f[xOffset], s[yOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto yOffset = other->getOffset(e);
+                    auto zOffset = target->getOffset(e);
+
+                    z[zOffset] = func(f[xOffset], s[yOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        }
     }
 }
@@ -161,29 +175,36 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
     if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for (int e = 0; e < _length; e++)
-            z[e] = func(f[e]);
+        auto loop = PRAGMA_THREADS_FOR {
+            for (auto e = start; e < stop; e += increment)
+                z[e] = func(f[e]);
+        };
+
+        samediff::Threads::parallel_for(loop, 0, _length);
     } else {
         if (f == z) {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (int e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                f[xOffset] = func(f[xOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+
+                    f[xOffset] = func(f[xOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        } else {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (int e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto zOffset = target->getOffset(e);
-                z[zOffset] = func(f[xOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto zOffset = target->getOffset(e);
+
+                    z[zOffset] = func(f[xOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        }
     }
 }
@@ -217,29 +238,36 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
     if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for (Nd4jLong e = 0; e < _length; e++)
-            z[e] = func(e, f[e]);
+        auto loop = PRAGMA_THREADS_FOR {
+            for (auto e = start; e < stop; e += increment)
+                z[e] = func(e, f[e]);
+        };
+
+        samediff::Threads::parallel_for(loop, 0, _length);
     } else {
         if (f == z) {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                f[xOffset] = func(e, f[xOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+
+                    f[xOffset] = func(e, f[xOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        } else {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (Nd4jLong e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto zOffset = target->getOffset(e);
-                z[zOffset] = func(e, f[xOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto zOffset = target->getOffset(e);
+
+                    z[zOffset] = func(e, f[xOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        }
     }
 }
@@ -282,31 +310,38 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
     if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
-        PRAGMA_OMP_PARALLEL_FOR_SIMD
-        for (Nd4jLong e = 0; e < _length; e++)
-            z[e] = func((Nd4jLong) e, f[e], s[e]);
+        auto loop = PRAGMA_THREADS_FOR {
+            for (auto e = start; e < stop; e += increment)
+                z[e] = func((Nd4jLong) e, f[e], s[e]);
+        };
+
+        samediff::Threads::parallel_for(loop, 0, _length);
     } else {
         if (f == z) {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (int e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto yOffset = other->getOffset(e);
-                f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto yOffset = other->getOffset(e);
+
+                    f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        } else {
-            PRAGMA_OMP_PARALLEL_FOR_SIMD
-            for (int e = 0; e < _length; e++) {
-                auto xOffset = this->getOffset(e);
-                auto yOffset = other->getOffset(e);
-                auto zOffset = target->getOffset(e);
-                z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
-            }
+            auto loop = PRAGMA_THREADS_FOR {
+                for (auto e = start; e < stop; e += increment) {
+                    auto xOffset = this->getOffset(e);
+                    auto yOffset = other->getOffset(e);
+                    auto zOffset = target->getOffset(e);
+
+                    z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
+                }
+            };
+
+            samediff::Threads::parallel_for(loop, 0, _length);
        }
     }
 }

View File

@@ -20,6 +20,8 @@
 #include "NativeOpExecutioner.h"
 #include <types/types.h>
+#include <LoopKind.h>
 #include <pairwise_bool.h>
 #include <broadcasting_bool.h>
 #include <scalar_bool.h>
@@ -50,11 +52,14 @@
 #include <loops/random.h>
 #include <pointercast.h>
 #include <exceptions/datatype_exception.h>
+#include <array/TadPack.h>
+#include <helpers/ConstantTadHelper.h>
 #ifdef _OPENMP
 #include <omp.h>
-#include <helpers/ConstantTadHelper.h>
 #endif
@@ -78,9 +83,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op
                                                 void *hZ, Nd4jLong *hZShapeInfo,
                                                 void *dZ, Nd4jLong *dZShapeInfo) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -111,9 +114,7 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -149,9 +150,7 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
                                         Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                         Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -160,7 +159,16 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
 #ifdef __ND4J_EXPERIMENTAL__
     BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
 #else
-    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = xLen / yLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 #endif
 }
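
In the broadcast paths above, work is now parallelized across tensors-along-dimension (TADs) rather than with OpenMP inside the loop kernels, and the TAD count is obtained by dividing the longer array's length by the shorter one's. A worked example of that arithmetic, under assumed shapes, follows.

    // Worked example of the numTads arithmetic used above (shapes assumed for illustration).
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Broadcasting a length-3 vector Y over a 4x3 matrix X: each row of X is one TAD.
        int64_t xLen = 4 * 3;            // shape::length(hXShapeInfo) for a 4x3 array
        int64_t yLen = 3;                // shape::length(hYShapeInfo) for a length-3 vector
        int64_t numTads = xLen / yLen;   // 4 rows -> samediff::Threads::parallel_tad(func, 0, numTads)
        printf("numTads = %lld\n", (long long) numTads);
        return 0;
    }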
@@ -179,9 +187,7 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     if (!nd4j::Environment::getInstance()->isExperimentalBuild())
         if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
@@ -190,7 +196,15 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
 #ifdef __ND4J_EXPERIMENTAL__
     BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
 #else
-    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = yLen / xLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 #endif
 }
@@ -208,15 +222,21 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
                                             int *dimension, int dimensionLength,
                                             Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                             Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = xLen / yLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 }
 void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
@@ -231,9 +251,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
                                                    Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                                    Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -243,7 +261,15 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
         if (yType != xType || nd4j::DataType::BOOL != zType)
             throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = yLen / xLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 }
@@ -260,9 +286,7 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
                                            int *dimension, int dimensionLength,
                                            Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                            Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -274,7 +298,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
     if (!nd4j::DataTypeUtils::isZ(zType))
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
-    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = xLen / yLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 }
 void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
@@ -289,21 +321,27 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                                   Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                                   Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
     if (xType != yType || xType != zType)
-        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
+        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
     if (!nd4j::DataTypeUtils::isZ(zType))
-        throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
+        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType);
-    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt,::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
+    };
+
+    auto xLen = shape::length(hXShapeInfo);
+    auto yLen = shape::length(hYShapeInfo);
+    auto numTads = yLen / xLen;
+
+    samediff::Threads::parallel_tad(func, 0, numTads);
 }
 ////////////////////////////////////////////////////////////////////////
@@ -328,9 +366,7 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
                                                 void *hZ, Nd4jLong *hZShapeInfo,
                                                 void *dZ, Nd4jLong *dZShapeInfo,
                                                 void *extraParams) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -339,7 +375,15 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
 #ifdef __ND4J_EXPERIMENTAL__
     BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
 #else
-    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform,
+                                     ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop),
+                                     LIBND4J_TYPES);
+    };
+
+    auto zLen = shape::length(hZShapeInfo);
+    samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
 #endif
 }
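
The pairwise paths above pick the thread count from the output length, roughly one thread per 1024 elements, clamped between 1 and Environment::maxThreads(), so small arrays stay single-threaded. A hedged restatement of that heuristic with illustrative values:

    // Restating the thread-count heuristic used in the pairwise paths above (illustrative only).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static int threadsFor(int64_t zLen, int maxThreads) {
        // roughly one thread per 1024 elements, at least 1, never more than the configured maximum
        return (int) std::max<int64_t>(1, std::min<int64_t>(zLen / 1024, maxThreads));
    }

    int main() {
        printf("%d\n", threadsFor(512, 16));      // 1  -> small arrays stay single-threaded
        printf("%d\n", threadsFor(8192, 16));     // 8
        printf("%d\n", threadsFor(1 << 20, 16));  // 16 -> capped at maxThreads
        return 0;
    }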
@@ -353,9 +397,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
                                                     void *hZ, Nd4jLong *hZShapeInfo,
                                                     void *dZ, Nd4jLong *dZShapeInfo,
                                                     void *extraParams) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -367,7 +409,13 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
     if (zType != nd4j::DataType::BOOL)
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES);
+    };
+
+    auto zLen = shape::length(hZShapeInfo);
+    samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
 }
 ////////////////////////////////////////////////////////////////////////
@@ -380,9 +428,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
                                                    void *hZ, Nd4jLong *hZShapeInfo,
                                                    void *dZ, Nd4jLong *dZShapeInfo,
                                                    void *extraParams) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -394,7 +440,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
     if (!nd4j::DataTypeUtils::isZ(zType))
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType);
-    BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES);
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), INTEGER_TYPES);
+    };
+
+    auto zLen = shape::length(hZShapeInfo);
+    samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
 }
 ////////////////////////////////////////////////////////////////////////
@@ -417,14 +469,22 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
+    // nothing to do here if result is empty
+    if (shape::isEmpty(hZShapeInfo))
+        return;
+
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
+    };
+
+    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
+    samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
 }
 ////////////////////////////////////////////////////////////////////////
@@ -437,14 +497,22 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc,
                                          void *dZ, Nd4jLong *dZShapeInfo,
                                          int *dimension, int dimensionLength,
                                          Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-    BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
+    // nothing to do here if result is empty
+    if (shape::isEmpty(hZShapeInfo))
+        return;
+
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES);
+    };
+
+    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
+    samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
 }
 ////////////////////////////////////////////////////////////////////////
@@ -457,14 +525,22 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc,
                                          void *dZ, Nd4jLong *dZShapeInfo,
                                          int *dimension, int dimensionLength,
                                          Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
+    // nothing to do here if result is empty
+    if (shape::isEmpty(hZShapeInfo))
+        return;
+
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES);
+    };
+
+    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
+    samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
 }
 ////////////////////////////////////////////////////////////////////////
@@ -477,14 +553,22 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc,
                                          void *dZ, Nd4jLong *dZShapeInfo,
                                          int *dimension, int dimensionLength,
                                          Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
-#ifdef _OPENMP
-    omp_set_nested(1);
-#endif
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
+    // nothing to do here if result is empty
+    if (shape::isEmpty(hZShapeInfo))
+        return;
+
+    auto func = PRAGMA_THREADS_FOR {
+        BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES);
+    };
+
+    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
+    samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
 }
 ////////////////////////////////////////////////////////////////////////
@ -503,9 +587,7 @@ void NativeOpExecutioner::execReduceFloatScalar(nd4j::LaunchContext *lc,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -521,9 +603,7 @@ void NativeOpExecutioner::execReduceSameScalar(nd4j::LaunchContext *lc,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
@ -539,9 +619,7 @@ void NativeOpExecutioner::execReduceBoolScalar(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -557,9 +635,7 @@ void NativeOpExecutioner::execReduceLongScalar(nd4j::LaunchContext *lc,
void *extraParams, void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -591,10 +667,6 @@ void NativeOpExecutioner::execReduce3Scalar(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -623,15 +695,13 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo) { void *dZ, Nd4jLong *dZShapeInfo) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 1), LIBND4J_TYPES, FLOAT_TYPES); //BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES);
NativeOpExecutioner::execReduce3Scalar(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -647,14 +717,31 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets, Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength), LIBND4J_TYPES, FLOAT_TYPES); const auto xLen = shape::length(hXShapeInfo);
const auto yLen = shape::length(hYShapeInfo);
nd4j::TadPack tadPack;
if(xLen == yLen) {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
}
else if(yLen > xLen) {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
}
else {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
}
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
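For the pairwise Reduce3 variants the number of parallel work items comes from a TAD pack, and the pack is built from whichever operand is larger: Y drives the TADs only when yLen > xLen, otherwise X does (including the equal-length case). A tiny sketch of that selection, with illustrative names:

// Sketch of the operand-selection branch used when building the TAD pack.
#include <cstdint>

enum class TadSource { X, Y };

inline TadSource pickTadSource(int64_t xLen, int64_t yLen) {
    // Y owns the TADs only when it is strictly longer; ties fall back to X,
    // matching the if / else-if / else chain in the hunks above.
    return (yLen > xLen) ? TadSource::Y : TadSource::X;
}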
@ -671,15 +758,19 @@ void NativeOpExecutioner::execReduce3All(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
// BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES);
// TODO: make it 2d
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -696,15 +787,31 @@ void NativeOpExecutioner::execReduce3TAD(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) { Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); const auto xLen = shape::length(hXShapeInfo);
// BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); const auto yLen = shape::length(hYShapeInfo);
nd4j::TadPack tadPack;
if(xLen == yLen) {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
}
else if(yLen > xLen) {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
}
else {
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
}
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
};
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
} }
@ -729,9 +836,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
void *hScalar, Nd4jLong *hScalarShapeInfo, void *hScalar, Nd4jLong *hScalarShapeInfo,
void *dScalar, Nd4jLong *dScalarShapeInfo, void *dScalar, Nd4jLong *dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@ -743,7 +848,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, allowParallelism), LIBND4J_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform,::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES);
};
auto zLen = shape::length(hZShapeInfo);
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
#endif #endif
} }
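The scalar transforms switch from an internal OpenMP decision to an explicit parallel_for whose thread count is derived from the output length: roughly one thread per 1024 elements, clamped to [1, maxThreads], and forced to 1 when allowParallelism is false. A sketch of that heuristic (the 1024-element granularity is taken from the calls in this diff, not from a documented contract):

// Thread-count heuristic assumed from the parallel_for calls above.
#include <algorithm>
#include <cstdint>

inline int scalarThreads(int64_t zLen, int maxThreads, bool allowParallelism) {
    if (!allowParallelism)
        return 1;                                   // caller asked for single-threaded execution
    const int64_t byWork = zLen / 1024;             // ~1024 output elements per thread
    return static_cast<int>(std::max<int64_t>(1, std::min<int64_t>(byWork, maxThreads)));
}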
@ -760,9 +871,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@ -774,7 +883,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
};
auto yLen = shape::length(hScalarShapeInfo);
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
#endif #endif
} }
@ -789,9 +904,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo, void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
@ -803,7 +916,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
if (zType != nd4j::DataType::BOOL) if (zType != nd4j::DataType::BOOL)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, BOOL_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES);
};
auto zLen = shape::length(hZShapeInfo);
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -819,9 +938,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@ -833,7 +950,12 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
if (zType != nd4j::DataType::BOOL) if (zType != nd4j::DataType::BOOL)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
};
auto yLen = shape::length(hScalarShapeInfo);
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -847,9 +969,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo, void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
@ -861,7 +981,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
if (!nd4j::DataTypeUtils::isZ(zType)) if (!nd4j::DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType);
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), INTEGER_TYPES);
};
auto zLen = shape::length(hZShapeInfo);
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -877,9 +1003,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@ -891,7 +1015,12 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
if (!nd4j::DataTypeUtils::isZ(zType)) if (!nd4j::DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType);
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
};
auto yLen = shape::length(hScalarShapeInfo);
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -912,9 +1041,7 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -940,9 +1067,7 @@ void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected) { bool biasCorrected) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -972,10 +1097,6 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool biasCorrected) { bool biasCorrected) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -1002,14 +1123,14 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES); auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
};
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
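The element-wise transforms use PRAGMA_THREADS_DO / parallel_do instead of a range split: every worker is started exactly once and receives its own (thread_id, numThreads) pair, and the transform kernel itself decides which stripe of the output it writes. A standalone sketch of such a helper, again with std::thread standing in for the thread pool:

// Illustrative parallel_do-style helper: launch numThreads workers, each told
// its id and the total count; striding over the data is the kernel's job.
#include <functional>
#include <thread>
#include <vector>

using DoFunc = std::function<void(int thread_id, int numThreads)>;

void parallel_do_sketch(const DoFunc &func, int numThreads) {
    if (numThreads <= 1) {          // degenerate case: run on the calling thread
        func(0, 1);
        return;
    }
    std::vector<std::thread> workers;
    for (int t = 0; t < numThreads; ++t)
        workers.emplace_back(func, t, numThreads);
    for (auto &w : workers) w.join();
}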
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1021,14 +1142,14 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES); auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
};
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1040,14 +1161,14 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), LIBND4J_TYPES, LIBND4J_TYPES); auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1059,14 +1180,14 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES); auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
};
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1078,14 +1199,14 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), FLOAT_TYPES); auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
};
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -1095,9 +1216,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -1116,9 +1235,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@ -1139,9 +1256,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraArguments) { void *extraArguments) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);

@ -28,7 +28,6 @@
#include <templatemath.h> #include <templatemath.h>
#include <types/float8.h> #include <types/float8.h>
#include <loops/type_conversions.h> #include <loops/type_conversions.h>
#include <loops/aggregates.h>
#include <helpers/helper_ptrmap.h> #include <helpers/helper_ptrmap.h>
#include <helpers/logger.h> #include <helpers/logger.h>
#include <pointercast.h> #include <pointercast.h>
@ -36,6 +35,7 @@
#include <types/types.h> #include <types/types.h>
#include <ops/declarable/helpers/transforms.h> #include <ops/declarable/helpers/transforms.h>
#include <exceptions/allocation_exception.h> #include <exceptions/allocation_exception.h>
#include <helpers/BlasHelper.h>
#include <fcntl.h> #include <fcntl.h>
@ -75,6 +75,7 @@ bool experimentalSupport = false;
#include <performance/benchmarking/BenchmarkSuit.h> #include <performance/benchmarking/BenchmarkSuit.h>
#include <performance/benchmarking/FullBenchmarkSuit.h> #include <performance/benchmarking/FullBenchmarkSuit.h>
#include <performance/benchmarking/LightBenchmarkSuit.h> #include <performance/benchmarking/LightBenchmarkSuit.h>
#include <execution/Threads.h>
#ifdef CPU_FEATURES #ifdef CPU_FEATURES
#include <cpuinfo_x86.h> #include <cpuinfo_x86.h>
@ -1152,10 +1153,7 @@ void initializeFunctions(Nd4jPointer *functions) {
* @param flags optional parameter * @param flags optional parameter
*/ */
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
Nd4jPointer pointer = (Nd4jPointer) malloc(memorySize); return reinterpret_cast<Nd4jPointer>(new int8_t[memorySize]);
if (pointer == 0)
return 0L;
return pointer;
} }
/** /**
@ -1179,7 +1177,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
* @param pointer pointer that'll be freed * @param pointer pointer that'll be freed
*/ */
int freeHost(Nd4jPointer pointer) { int freeHost(Nd4jPointer pointer) {
free(reinterpret_cast<void *>(pointer)); delete[] reinterpret_cast<int8_t *>(pointer);
return 1L; return 1L;
} }
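mallocHost and freeHost now use new[] and delete[] rather than malloc/free, so the two entry points have to change together; memory obtained with new[] must never be released with free(). A minimal sketch of the paired behaviour (names suffixed with Sketch are illustrative):

// Paired host allocation after this change: new[] on one side, delete[] on
// the other. Note that new[] throws std::bad_alloc on failure instead of
// returning a null pointer the way malloc did.
#include <cstdint>

typedef void* NdPointerSketch;

NdPointerSketch mallocHostSketch(long long memorySize) {
    return reinterpret_cast<NdPointerSketch>(new int8_t[memorySize]);
}

int freeHostSketch(NdPointerSketch pointer) {
    delete[] reinterpret_cast<int8_t*>(pointer);
    return 1;
}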
@ -1364,37 +1362,37 @@ void pullRowsGeneric(void *vx,
int elementsPerThread = n / TAD_THRESHOLD; int elementsPerThread = n / TAD_THRESHOLD;
int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread); int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
_threads = nd4j::math::nd4j_min<int>(_threads, omp_get_max_threads()); _threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());
PRAGMA_OMP_PARALLEL_FOR_THREADS(_threads) auto func = PRAGMA_THREADS_FOR {
for (int idx = 0; idx < n; idx++) { for (auto idx = start; idx < stop; idx += increment) {
auto xTadOffsetForBlock = tadOffsets[indexes[idx]]; auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
auto zTadOffsetForBlock = zTadOffsets[idx]; auto zTadOffsetForBlock = zTadOffsets[idx];
auto rX = hX + xTadOffsetForBlock; auto rX = hX + xTadOffsetForBlock;
auto rZ = hZ + zTadOffsetForBlock; auto rZ = hZ + zTadOffsetForBlock;
if (xEWS == 1 && zEWS == 1) { if (xEWS == 1 && zEWS == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_SIMD for (int i = 0; i < tadLength; i++) {
for (int i = 0; i < tadLength; i++ ) { rZ[i] = rX[i];
rZ[i] = rX[i]; }
} } else if (xEWS >= 1 && zEWS >= 1) {
} else if (xEWS >= 1 && zEWS >= 1) { PRAGMA_OMP_SIMD
for (int i = 0; i < tadLength; i++) {
PRAGMA_OMP_SIMD rZ[i * zEWS] = rX[i * xEWS];
for (int i = 0; i < tadLength; i++ ) { }
rZ[i * zEWS] = rX[i * xEWS]; } else {
for (int i = 0; i < tadLength; i++) {
auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo);
auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
hZ[zOffset] = hX[xOffset];
}
} }
} }
else { };
for (int i = 0; i < tadLength; i++) {
auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo); samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
hZ[zOffset] = hX[xOffset];
}
}
}
} }
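pullRowsGeneric gathers the TADs named by indexes into consecutive TADs of the output, with a straight copy when both element-wise strides are 1, a strided copy when they are positive, and a per-element shape::getIndexOffset fallback otherwise; the loop over idx is what parallel_tad now splits. A simplified sketch of the contiguous fast path, written as a plain 2-D row gather:

// Simplified row gather corresponding to the xEWS == 1 && zEWS == 1 branch.
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
void pullRowsContiguous(const T *x, T *z, int64_t rowLength,
                        const std::vector<int64_t> &rowIndexes) {
    for (std::size_t r = 0; r < rowIndexes.size(); ++r) {
        const T *src = x + rowIndexes[r] * rowLength;            // requested source row (a TAD of X)
        T *dst = z + static_cast<int64_t>(r) * rowLength;        // next output row
        for (int64_t i = 0; i < rowLength; ++i)
            dst[i] = src[i];
    }
}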
void pullRows(Nd4jPointer *extraPointers, void pullRows(Nd4jPointer *extraPointers,
@ -1433,30 +1431,29 @@ void tearGeneric(void *vx,
auto zEWS = shape::elementWiseStride(hZShapeInfo); auto zEWS = shape::elementWiseStride(hZShapeInfo);
auto numTads = shape::length(hXShapeInfo) / tadLength; auto numTads = shape::length(hXShapeInfo) / tadLength;
PRAGMA_OMP_PARALLEL_FOR auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong i = 0; i < numTads; i++) { for (auto i = start; i < stop; i += increment) {
auto hZ = reinterpret_cast<T *>(targets[i]); auto hZ = reinterpret_cast<T *>(targets[i]);
auto s = hX + tadOffsets[i]; auto s = hX + tadOffsets[i];
if (zEWS == 1 && tadEWS == 1) { if (zEWS == 1 && tadEWS == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_SIMD for (Nd4jLong j = 0; j < tadLength; j++) {
for (Nd4jLong j = 0; j < tadLength; j++) { hZ[j] = s[j];
hZ[j] = s[j]; }
} } else if (zEWS > 0 && tadEWS > 0) {
} else if (zEWS > 0 && tadEWS > 0) { PRAGMA_OMP_SIMD
for (Nd4jLong j = 0; j < tadLength; j++) {
PRAGMA_OMP_SIMD hZ[j * zEWS] = s[j * tadEWS];
for (Nd4jLong j = 0; j < tadLength; j++) { }
hZ[j * zEWS] = s[j * tadEWS]; } else {
for (Nd4jLong j = 0; j < tadLength; j++)
hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
} }
} }
else { };
for (Nd4jLong j = 0; j < tadLength; j++) samediff::Threads::parallel_tad(func,0, numTads);
hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
}
}
} }
void tear(Nd4jPointer *extraPointers, void tear(Nd4jPointer *extraPointers,
@ -1557,57 +1554,60 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
auto dX = reinterpret_cast<T **>(hX); auto dX = reinterpret_cast<T **>(hX);
auto dZ = reinterpret_cast<T **>(dz); auto dZ = reinterpret_cast<T **>(dz);
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(N) auto func = PRAGMA_THREADS_FOR {
for (int f = 0; f < N; f++) { for (auto f = start; f < stop; f += increment) {
auto hX = reinterpret_cast<T *>(dX[f]); auto hX = reinterpret_cast<T *>(dX[f]);
//auto hZ = reinterpret_cast<T *>(dZ[f]); //auto hZ = reinterpret_cast<T *>(dZ[f]);
auto xShapeInfo = hXShapeInfo[f]; auto xShapeInfo = hXShapeInfo[f];
auto tadOffset = reinterpret_cast<Nd4jLong *>(tadOffsets[f]); auto tadOffset = reinterpret_cast<Nd4jLong *>(tadOffsets[f]);
const auto tadLength = shape::length(tadOnlyShapeInfo[f]); const auto tadLength = shape::length(tadOnlyShapeInfo[f]);
auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]); auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]);
auto tadRank = shape::rank(tadOnlyShapeInfo[f]); auto tadRank = shape::rank(tadOnlyShapeInfo[f]);
auto numTads = shape::length(hXShapeInfo[f]) / tadLength; auto numTads = shape::length(hXShapeInfo[f]) / tadLength;
auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]); auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]);
auto tadStride = shape::stride(tadOnlyShapeInfo[f]); auto tadStride = shape::stride(tadOnlyShapeInfo[f]);
if (shape::rank(xShapeInfo) == 1) { if (shape::rank(xShapeInfo) == 1) {
auto xLength = shape::length(xShapeInfo); auto xLength = shape::length(xShapeInfo);
auto ews = shape::elementWiseStride(xShapeInfo); auto ews = shape::elementWiseStride(xShapeInfo);
for (Nd4jLong r = 0; r < xLength; r++) { for (Nd4jLong r = 0; r < xLength; r++) {
auto swapIdx = shuffleMap[r]; auto swapIdx = shuffleMap[r];
if (swapIdx < 0) if (swapIdx < 0)
continue; continue;
nd4j::math::nd4j_swap<T>(hX[r*ews], hX[swapIdx*ews]); nd4j::math::nd4j_swap<T>(hX[r * ews], hX[swapIdx * ews]);
} }
} else { } else {
for (Nd4jLong r = 0; r < numTads; r++) { for (Nd4jLong r = 0; r < numTads; r++) {
if (shuffleMap[r] < 0) if (shuffleMap[r] < 0)
continue; continue;
auto oldOffset = tadOffset[r]; auto oldOffset = tadOffset[r];
auto newOffset = tadOffset[shuffleMap[r]]; auto newOffset = tadOffset[shuffleMap[r]];
auto rX = hX + oldOffset; auto rX = hX + oldOffset;
auto rY = hX + newOffset; auto rY = hX + newOffset;
if (tadEWS == 1) { if (tadEWS == 1) {
for (Nd4jLong i = 0; i < tadLength; i++) { for (Nd4jLong i = 0; i < tadLength; i++) {
nd4j::math::nd4j_swap<T>(rX[i], rY[i]); nd4j::math::nd4j_swap<T>(rX[i], rY[i]);
} }
} else { } else {
for (Nd4jLong i = 0; i < tadLength; i++) { for (Nd4jLong i = 0; i < tadLength; i++) {
auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]); auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
nd4j::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]); nd4j::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]);
}
} }
} }
} }
} }
} };
samediff::Threads::parallel_tad(func, 0, N);
} }
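shuffleGeneric permutes arrays in place: TAD (or element, for rank-1 inputs) r is swapped with TAD shuffleMap[r], and a negative map entry means that position is left alone. A sketch of the same semantics for the rank-1 case:

// Rank-1 shuffle semantics: swap element r with element shuffleMap[r];
// negative entries mean "do not move this element".
#include <cstddef>
#include <utility>
#include <vector>

template <typename T>
void shuffleInPlace(std::vector<T> &data, const std::vector<long long> &shuffleMap) {
    for (std::size_t r = 0; r < data.size() && r < shuffleMap.size(); ++r) {
        const long long swapIdx = shuffleMap[r];
        if (swapIdx < 0)
            continue;
        std::swap(data[r], data[static_cast<std::size_t>(swapIdx)]);
    }
}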
void shuffle(Nd4jPointer *extras, void shuffle(Nd4jPointer *extras,
@ -1772,72 +1772,9 @@ void execAggregate(Nd4jPointer *extraPointers,int opNum,
void *realArguments, void *realArguments,
int numRealArguments, int numRealArguments,
nd4j::DataType dtype) { nd4j::DataType dtype) {
try {
BUILD_SINGLE_SELECTOR(dtype, NativeOpExecutioner::execAggregate, (nullptr, opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), FLOAT_TYPES);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
}
} }
template <typename T>
void _batchExecutor(Nd4jPointer *extraPointers,
int numAggregates,
int opNum,
int maxArgs,
int maxShapes,
int maxIntArrays,
int maxIntArraySize,
int maxIdx,
int maxReals,
void *ptrToArguments,
nd4j::DataType dtype) {
// probably, we don't want too much threads as usually
int _threads = nd4j::math::nd4j_min<int>(numAggregates, omp_get_max_threads());
nd4j::PointersHelper<T> helper(ptrToArguments,
numAggregates,
maxArgs,
maxShapes,
maxIntArrays,
maxIntArraySize,
maxIdx,
maxReals);
// special case here, we prefer spread arrangement here, all threads are detached from each other
PRAGMA_OMP_PARALLEL_FOR_THREADS(_threads)
for (int i = 0; i < numAggregates; i++) {
auto intArrays = new int *[maxIntArrays];
auto arguments = helper.getArguments(i);
auto shapes = helper.getShapeArguments(i);
auto idxArg = helper.getIndexArguments(i);
auto realArg = helper.getRealArguments(i);
for (int e = 0; e < maxIntArrays; e++) {
intArrays[e] = helper.getIntArrayArguments(i, e);
}
execAggregate(extraPointers,
opNum,
reinterpret_cast<void **>(arguments),
helper.getNumArguments(i),
shapes,
helper.getNumShapeArguments(i),
idxArg,
helper.getNumIndexArguments(i),
intArrays,
helper.getNumIntArrayArguments(i),
realArg,
helper.getNumRealArguments(i),
dtype);
delete [] intArrays;
}
}
BUILD_SINGLE_TEMPLATE(template void _batchExecutor, (Nd4jPointer *extraPointers, int numAggregates, int opNum, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments, nd4j::DataType dtype), FLOAT_TYPES);
void batchExecutor(Nd4jPointer *extraPointers, void batchExecutor(Nd4jPointer *extraPointers,
int numAggregates, int numAggregates,
int opNum, int opNum,
@ -1849,12 +1786,7 @@ void batchExecutor(Nd4jPointer *extraPointers,
int maxReals, int maxReals,
void *ptrToArguments, void *ptrToArguments,
nd4j::DataType dtype) { nd4j::DataType dtype) {
try {
BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
}
} }
void execAggregateBatch(Nd4jPointer *extraPointers, void execAggregateBatch(Nd4jPointer *extraPointers,
@ -1868,12 +1800,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
int maxReals, int maxReals,
void *ptrToArguments, void *ptrToArguments,
nd4j::DataType dtype) { nd4j::DataType dtype) {
try {
BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
}
} }
@ -2094,27 +2021,21 @@ const char* getAllCustomOps() {
template <typename T> template <typename T>
FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) { FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) {
auto buffer = reinterpret_cast<T *>(hX); auto buffer = reinterpret_cast<T *>(hX);
int span = (N / 6) + 8; int span = (N / 6) + 8;
int cnt = 0;
PRAGMA_OMP_PARALLEL_REDUCTION(+:cnt)
{
int tid = omp_get_thread_num();
int start = span * tid;
int stop = span * (tid + 1);
if (stop > N)
stop = N;
auto func = PRAGMA_REDUCE_LONG {
int64_t cnt = 0;
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int e = start; e < stop; e++) { for (auto e = start; e < stop; e++) {
auto v = nd4j::math::nd4j_abs<T>(buffer[e]); auto v = nd4j::math::nd4j_abs<T>(buffer[e]);
if (v >= threshold) if (v >= threshold)
cnt++; cnt++;
} }
}
return cnt; return cnt;
};
return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
} }
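estimateThresholdGeneric replaces the OpenMP reduction(+:cnt) block with parallel_long: every worker counts qualifying elements on its own slice, and the partial counts are merged by the combiner lambda (LAMBDA_AL { return _old + _new; }). A self-contained sketch of that shape of reduction:

// Per-thread partial counts merged after the join; the worker layout and the
// use of std::thread are assumptions made for the sake of a runnable example.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

int64_t countAboveThreshold(const float *buffer, int64_t n, float threshold, int numThreads) {
    numThreads = std::max(1, numThreads);
    std::vector<int64_t> partial(numThreads, 0);
    std::vector<std::thread> workers;
    const int64_t chunk = (n + numThreads - 1) / numThreads;

    for (int t = 0; t < numThreads; ++t) {
        workers.emplace_back([&, t]() {
            const int64_t start = t * chunk;
            const int64_t stop  = std::min<int64_t>(n, start + chunk);
            int64_t cnt = 0;
            for (int64_t e = start; e < stop; ++e)
                if (std::abs(buffer[e]) >= threshold)
                    ++cnt;
            partial[t] = cnt;                        // thread-local partial result
        });
    }
    for (auto &w : workers) w.join();

    int64_t total = 0;
    for (auto c : partial) total += c;               // combiner step: _old + _new
    return total;
}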
@ -2776,58 +2697,51 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) { void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
auto hIindexes = reinterpret_cast<I*>(vIindexes); auto hIindexes = reinterpret_cast<I*>(vIindexes);
auto func = PRAGMA_THREADS_DO {
int numThreads = omp_get_max_threads(); for (int i = 0; i < numOfSubArrs; ++i) {
int threadIndex = thread_id;
PRAGMA_OMP_PARALLEL_THREADS(numThreads)
{
for (int i = 0; i < numOfSubArrs; ++i) {
int threadIndex = omp_get_thread_num();
const auto xIndex = hIindexes[i]; const auto xIndex = hIindexes[i];
const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads;
if (!isOwner) if (!isOwner)
continue; continue;
NDArray inSubArr( NDArray inSubArr(reinterpret_cast<int8_t *>(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo);
reinterpret_cast<int8_t *>(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), NDArray updSubArr(reinterpret_cast<int8_t *>(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo);
hXShapeInfo);
NDArray updSubArr(reinterpret_cast<int8_t *>(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)),
hYShapeInfo);
if (inSubArr.lengthOf() != updSubArr.lengthOf()) { if (inSubArr.lengthOf() != updSubArr.lengthOf()) {
continue; continue;
} }
switch (opCode) { switch (opCode) {
case 0: case 0:
inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr);
break; break;
case 1: case 1:
inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr);
break; break;
case 2: case 2:
inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr);
break; break;
case 3: case 3:
inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr);
break; break;
case 4: case 4:
inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr);
break; break;
case 5: case 5:
inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr);
break; break;
case 6: case 6:
inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr);
break; break;
default: default:
continue; continue;
}
} }
} };
}
samediff::Threads::parallel_do(func);
} }
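The scatterUpdate rewrite parallelizes over threads rather than over sub-arrays: every worker scans the full index list but only applies an update when it owns the destination index, which keeps duplicate destination indices on a single thread and avoids races without locking. The ownership test, isolated as a sketch:

// Ownership rule from the loop above: an update is applied only by the thread
// whose id matches the destination index (modulo the thread count).
inline bool ownsIndex(long long xIndex, int threadIndex, int numThreads) {
    return xIndex < numThreads ? threadIndex == xIndex
                               : threadIndex == xIndex % numThreads;
}

With four workers, for example, every update addressed to sub-array 5 is executed by thread 1 (5 % 4), so two updates into the same sub-array never run concurrently.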
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -2847,6 +2761,7 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
} }
} }
void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) {
try { try {
auto p = reinterpret_cast<nd4j::DebugInfo *>(debugInfo); auto p = reinterpret_cast<nd4j::DebugInfo *>(debugInfo);
@ -25,6 +25,7 @@
#include <loops/transform_any.h> #include <loops/transform_any.h>
#include <loops/reduce_bool.h> #include <loops/reduce_bool.h>
#include <loops/reduce_long.h> #include <loops/reduce_long.h>
#include <loops/scalar.h>
#include <helpers/threshold.h> #include <helpers/threshold.h>
#include <ops/specials_cuda.h> #include <ops/specials_cuda.h>
#include <helpers/DebugHelper.h> #include <helpers/DebugHelper.h>
@ -33,8 +34,8 @@
#include <exceptions/datatype_exception.h> #include <exceptions/datatype_exception.h>
#include <exceptions/cuda_exception.h> #include <exceptions/cuda_exception.h>
#include <helpers/CudaLaunchHelper.h> #include <helpers/CudaLaunchHelper.h>
// FIXME: we need cuda-specific implementations
#include <GraphExecutioner.h> #include <GraphExecutioner.h>
#include <helpers/BlasHelper.h>
#include <graph/GraphHolder.h> #include <graph/GraphHolder.h>
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include <PointersManager.h> #include <PointersManager.h>
@ -1723,11 +1724,7 @@ void execScalarTad(Nd4jPointer *extraPointers,
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ,
dZShapeInfo, dScalars, extraParams, dimension,
dimensionLength, tadShapeInfo, tadOffsets,
tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
#endif #endif
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
@ -1750,23 +1747,7 @@ void execAggregate(Nd4jPointer *extraPointers,
void *realArguments, void *realArguments,
int numRealArguments, int numRealArguments,
nd4j::DataType dtype) { nd4j::DataType dtype) {
try {
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
int numBlocks = getDeviceId(extraPointers[2]);
int numThreads = getDeviceId(extraPointers[3]);
int shmem = getDeviceId(extraPointers[4]);
dim3 launchDims = dim3(numBlocks, numThreads, shmem);
BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction,
::aggregateKernelGeneric(launchDims, stream, opNum, arguments, numArguments, shapes,
numShapes, indexArguments, numIndexArguments, intArrays,
numIntArrays, realArguments, numRealArguments), FLOAT_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed");
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
}
} }
void batchExecutor(Nd4jPointer *extraPointers, void batchExecutor(Nd4jPointer *extraPointers,
@ -1788,25 +1769,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
int maxIntArrays, int maxIntArraySize, int maxIntArrays, int maxIntArraySize,
int maxIdx, int maxReals, int maxIdx, int maxReals,
void *ptrToArguments, nd4j::DataType dtype) { void *ptrToArguments, nd4j::DataType dtype) {
try {
// not implemented yet
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
int numBlocks = getDeviceId(extraPointers[2]);
int numThreads = getDeviceId(extraPointers[3]);
int shmem = getDeviceId(extraPointers[4]);
dim3 launchDims = dim3(numAggregates, numThreads, shmem);
BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction,
::aggregateBatchKernelGeneric(launchDims, stream, opNum, numAggregates, maxArgs,
maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals,
ptrToArguments), FLOAT_TYPES);
DEBUG_KERNEL(stream, opNum);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
}
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -53,6 +53,7 @@ CLEAN="false"
MINIFIER="false" MINIFIER="false"
TESTS="false" TESTS="false"
VERBOSE="false" VERBOSE="false"
VERBOSE_ARG="VERBOSE=1"
HELPER= HELPER=
NAME= NAME=
while [[ $# > 0 ]] while [[ $# > 0 ]]
@ -291,38 +292,37 @@ case "$OS" in
macosx*) macosx*)
# Do something under Mac OS X platform # Do something under Mac OS X platform
if [ "$CHIP" == "cuda" ]; then #if [ "$CHIP" == "cuda" ]; then
export CC=clang export CC=clang
export CXX=clang++ export CXX=clang++
PARALLEL="false"
else
export CC="$(ls -1 /usr/local/bin/gcc-? | head -n 1)"
export CXX="$(ls -1 /usr/local/bin/g++-? | head -n 1)"
PARALLEL="true" PARALLEL="true"
fi #else
# export CC="$(ls -1 /usr/local/bin/gcc-? | head -n 1)"
# export CXX="$(ls -1 /usr/local/bin/g++-? | head -n 1)"
# PARALLEL="true"
#fi
export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true" export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true"
;; ;;
windows*) windows*)
# Do something under Windows NT platform # Do something under Windows NT platform
if [ "$CHIP" == "cuda" ]; then if [ "$CHIP" == "cuda" ]; then
export CMAKE_COMMAND="cmake -G \"Ninja\"" export CMAKE_COMMAND="cmake -G \"Ninja\""
export MAKE_COMMAND="ninja" export MAKE_COMMAND="ninja"
export CC="cl.exe" export CC="cl.exe"
export CXX="cl.exe" export CXX="cl.exe"
PARALLEL="true" PARALLEL="true"
else VERBOSE_ARG="-v"
else
export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\"" export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\""
export MAKE_COMMAND="make" export MAKE_COMMAND="make"
# Sam, do we really need this?
export CC=/mingw64/bin/gcc export CC=/mingw64/bin/gcc
export CXX=/mingw64/bin/g++ export CXX=/mingw64/bin/g++
PARALLEL="true" PARALLEL="true"
fi
fi # Try some defaults for Visual Studio 2013 if user has not run vcvarsall.bat or something
# Try some defaults for Visual Studio 2013 if user has not run vcvarsall.bat or something if [ -z "${VCINSTALLDIR:-}" ]; then
if [ -z "${VCINSTALLDIR:-}" ]; then
export VisualStudioVersion=12.0 export VisualStudioVersion=12.0
export VSINSTALLDIR="C:\\Program Files (x86)\\Microsoft Visual Studio $VisualStudioVersion" export VSINSTALLDIR="C:\\Program Files (x86)\\Microsoft Visual Studio $VisualStudioVersion"
export VCINSTALLDIR="$VSINSTALLDIR\\VC" export VCINSTALLDIR="$VSINSTALLDIR\\VC"
@ -332,10 +332,10 @@ case "$OS" in
export LIB="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\lib\\winv6.3\\um\\x64" export LIB="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\lib\\winv6.3\\um\\x64"
export LIBPATH="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\References\\CommonConfiguration\\Neutral" export LIBPATH="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\References\\CommonConfiguration\\Neutral"
export PATH="$PATH:$VCINSTALLDIR\\BIN\\amd64:$WindowsSdkDir\\bin\\x64:$WindowsSdkDir\\bin\\x86" export PATH="$PATH:$VCINSTALLDIR\\BIN\\amd64:$WindowsSdkDir\\bin\\x64:$WindowsSdkDir\\bin\\x86"
fi fi
# Make sure we are using 64-bit MinGW-w64 # Make sure we are using 64-bit MinGW-w64
export PATH=/mingw64/bin/:$PATH export PATH=/mingw64/bin/:/mingw64/lib:$PATH
# export GENERATOR="MSYS Makefiles" # export GENERATOR="MSYS Makefiles"
;; ;;
esac esac
@ -534,6 +534,6 @@ if [ "$PARALLEL" == "true" ]; then
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
fi fi
if [ "$VERBOSE" == "true" ]; then if [ "$VERBOSE" == "true" ]; then
MAKE_ARGUMENTS="$MAKE_ARGUMENTS VERBOSE=1" MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG"
fi fi
eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../.. eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../..
@ -29,6 +29,7 @@
#include <helpers/BitwiseUtils.h> #include <helpers/BitwiseUtils.h>
#include <loops/type_conversions.h> #include <loops/type_conversions.h>
#include <dll.h> #include <dll.h>
#include <execution/Threads.h>
namespace nd4j { namespace nd4j {
template <typename T> template <typename T>
@ -50,9 +51,12 @@ namespace nd4j {
else else
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer); TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
#else #else
PRAGMA_OMP_PARALLEL_FOR_SIMD auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong e = 0; e < length; e++) for (auto e = start; e < stop; e += increment)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
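The raw-buffer readers now perform their element conversion inside a PRAGMA_THREADS_FOR loop handed to parallel_for, casting each stored value and byte-swapping it when the serialized endianness does not match the host's. A sketch of that per-element step; swapBytes here is a generic illustration, not the library's BitwiseUtils::swap_bytes:

// Cast each source element to the destination type; reverse its byte order
// when the serialized endianness differs from the host (canKeep == false).
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T>
T swapBytes(T value) {
    unsigned char bytes[sizeof(T)];
    std::memcpy(bytes, &value, sizeof(T));
    for (std::size_t i = 0; i < sizeof(T) / 2; ++i) {
        unsigned char tmp = bytes[i];
        bytes[i] = bytes[sizeof(T) - 1 - i];
        bytes[sizeof(T) - 1 - i] = tmp;
    }
    std::memcpy(&value, bytes, sizeof(T));
    return value;
}

template <typename SrcT, typename DstT>
void convertBuffer(const SrcT *tmp, DstT *buffer, int64_t length, bool canKeep) {
    for (int64_t e = 0; e < length; ++e)
        buffer[e] = canKeep ? static_cast<DstT>(tmp[e])
                            : swapBytes(static_cast<DstT>(tmp[e]));
}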
@ -105,9 +109,12 @@ namespace nd4j {
else else
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer); TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
#else #else
PRAGMA_OMP_PARALLEL_FOR_SIMD auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong e = 0; e < length; e++) for (auto e = start; e < stop; e += increment)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
@ -130,9 +137,12 @@ namespace nd4j {
#else #else
PRAGMA_OMP_PARALLEL_FOR auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong e = 0; e < length; e++) for (auto e = start; e < stop; e += increment)
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e])); buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
};
samediff::Threads::parallel_for(func, 0, length);
#endif #endif
delete[] tmp; delete[] tmp;
} }
@ -153,9 +163,12 @@ namespace nd4j {
else
    TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
#else
- PRAGMA_OMP_PARALLEL_FOR
- for (Nd4jLong e = 0; e < length; e++)
-     buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
+ auto func = PRAGMA_THREADS_FOR {
+     for (auto e = start; e < stop; e += increment)
+         buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
+ };
+ samediff::Threads::parallel_for(func, 0, length);
#endif
delete[] tmp;
}
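This is the conversion pattern the commit applies throughout: the OpenMP pragma becomes a FUNC_1D lambda built with PRAGMA_THREADS_FOR and dispatched via samediff::Threads::parallel_for. A minimal standalone sketch of the pattern (fillBuffer, buffer and length are hypothetical names, not code from this commit):

#include <execution/Threads.h>

static void fillBuffer(float *buffer, int64_t length) {
    // the lambda body only sees its own span: start, stop, increment, plus thread_id
    auto func = PRAGMA_THREADS_FOR {
        for (auto e = start; e < stop; e += increment)
            buffer[e] = static_cast<float>(e);
    };

    // may run with fewer threads than requested, or inline if no pool workers are free
    samediff::Threads::parallel_for(func, 0, length);
}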


@ -26,6 +26,7 @@
#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
+ #include <helpers/DebugHelper.h>
#endif
#include <dll.h>


@ -97,10 +97,10 @@ namespace cnpy {
* @param t
* @return
*/
- char mapType(const std::type_info &t);
+ ND4J_EXPORT char mapType(const std::type_info &t);
template <typename T>
- char mapType();
+ ND4J_EXPORT char mapType();
/**
*
@ -111,7 +111,7 @@ namespace cnpy {
* @return
*/
template<typename T>
- std::vector<char> createNpyHeader(const void *data,
+ ND4J_EXPORT std::vector<char> createNpyHeader(const void *data,
const unsigned int *shape,
const unsigned int ndims,
unsigned int wordSize = 4);
@ -126,7 +126,7 @@ namespace cnpy {
* @param ndims
* @param fortranOrder
*/
- void parseNpyHeader(FILE *fp,
+ ND4J_EXPORT void parseNpyHeader(FILE *fp,
unsigned int &wordSize,
unsigned int *&shape,
unsigned int &ndims,
@ -143,7 +143,7 @@ namespace cnpy {
* @param ndims
* @param fortran_order
*/
- void parseNpyHeaderPointer(
+ ND4J_EXPORT void parseNpyHeaderPointer(
const char *header,
unsigned int& word_size,
unsigned int*& shape,
@ -156,7 +156,7 @@ namespace cnpy {
* @param global_header_size
* @param global_header_offset
*/
- void parseZipFooter(FILE *fp,
+ ND4J_EXPORT void parseZipFooter(FILE *fp,
unsigned short &nrecs,
unsigned int &global_header_size,
unsigned int &global_header_offset);
@ -167,14 +167,14 @@ namespace cnpy {
* @param varname
* @return
*/
- NpyArray npzLoad(std::string fname, std::string varname);
+ ND4J_EXPORT NpyArray npzLoad(std::string fname, std::string varname);
/**
*
* @param fname
* @return
*/
- NpyArray npyLoad(std::string fname);
+ ND4J_EXPORT NpyArray npyLoad(std::string fname);
/**
* Parse the numpy header from
@ -187,7 +187,7 @@ namespace cnpy {
* @param ndims
* @param fortranOrder
*/
- void parseNpyHeaderStr(std::string header,
+ ND4J_EXPORT void parseNpyHeaderStr(std::string header,
unsigned int &wordSize,
unsigned int *&shape,
unsigned int &ndims,
@ -199,14 +199,14 @@ namespace cnpy {
* @param fp
* @return
*/
- int * shapeFromFile(FILE *fp);
+ ND4J_EXPORT int* shapeFromFile(FILE *fp);
/**
*
* @param data
* @return
*/
- int * shapeFromPointer(char *data);
+ ND4J_EXPORT int* shapeFromPointer(char *data);
/**
* Load the numpy array from the given file.
@ -250,7 +250,7 @@ namespace cnpy {
* @param ndims
* @param fortran_order
*/
- void parseNpyHeader(std::string header,
+ ND4J_EXPORT void parseNpyHeader(std::string header,
unsigned int &word_size,
unsigned int *&shape,
unsigned int &ndims,
@ -273,7 +273,7 @@ namespace cnpy {
template<typename T>
- void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w");
+ ND4J_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w");
}
@ -284,8 +284,8 @@ namespace cnpy {
* @param rhs
* @return
*/
template<typename T>
- std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs);
+ ND4J_EXPORT std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs);
#endif


@ -20,6 +20,9 @@
#ifndef NATIVEOPERATIONS_DLL_H
#define NATIVEOPERATIONS_DLL_H
+ #include <msvc.h>
#ifdef _WIN32
//#include <windows.h>
# define ND4J_EXPORT __declspec(dllexport)


@ -0,0 +1,52 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_BLOCKINGQUEUE_H
#define SAMEDIFF_BLOCKINGQUEUE_H
#include <functional>
#include <queue>
#include <mutex>
#include <atomic>
#include <condition_variable>
namespace samediff {
template <typename T>
class BlockingQueue {
private:
std::queue<T> _queue;
std::mutex _lock;
std::atomic<int> _size;
std::atomic<bool> _available;
std::condition_variable _condition;
public:
BlockingQueue(int queueSize);
~BlockingQueue() = default;
T poll();
void put(const T &t);
bool available();
void markAvailable();
void markUnavailable();
};
}
#endif //DEV_TESTS_BLOCKINGQUEUE_H
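A minimal sketch of how the queue is meant to be driven (hypothetical example, not shipped code; the commit only instantiates BlockingQueue for CallableWithArguments*, and the consumer side here mirrors executionLoop_ in ThreadPool.cpp):

#include <execution/BlockingQueue.h>
#include <execution/CallableWithArguments.h>
#include <thread>
#include <cstdio>

int main() {
    // queueSize is stored by the constructor but not otherwise used by the current implementation
    samediff::BlockingQueue<samediff::CallableWithArguments*> queue(2);

    // consumer: poll() blocks until the producer has put at least one element
    std::thread worker([&queue] {
        auto c = queue.poll();
        auto args = c->arguments();
        c->function_1d()(c->threadId(), args[0], args[1], args[2]);
        c->finish();
    });

    // producer: wrap a FUNC_1D (thread_id, start, stop, increment) and hand it over
    auto callable = new samediff::CallableWithArguments(
            [](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) {
                for (auto e = start; e < stop; e += increment)
                    printf("thread %i processed %i\n", (int) thread_id, (int) e);
            }, 0, 0, 5, 1);

    queue.put(callable);
    callable->waitUntilFinished();
    worker.join();
    delete callable;
    return 0;
}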


@ -0,0 +1,94 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_CALLABLEINTERFACE_H
#define SAMEDIFF_CALLABLEINTERFACE_H
#include <openmp_pragmas.h>
#include <cstdint>
#include <functional>
#include <atomic>
#include <array>
#include <mutex>
#include <condition_variable>
namespace samediff {
/**
* This class is suited for passing functions to execution threads without queues
*/
class CallableInterface {
private:
// parallel_for functions
FUNC_1D _function_1d;
FUNC_2D _function_2d;
FUNC_3D _function_3d;
// parallel function
FUNC_DO _function_do;
// reduction functions
FUNC_RL _function_rl;
FUNC_RD _function_rd;
std::array<int64_t, 9> _arguments;
volatile int _branch = 0;
volatile uint32_t _thread_id = 0;
volatile uint32_t _num_threads = 0;
std::atomic<bool> _finished;
std::atomic<bool> _filled;
std::atomic<bool> _available;
std::condition_variable _starter;
std::condition_variable _finisher;
int64_t* _lptr = nullptr;
double* _dptr = nullptr;
std::mutex _ms;
std::mutex _mf;
public:
CallableInterface();
~CallableInterface() = default;
void waitForTask();
void waitForCompletion();
void fill(int thread_id, int num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void fill(int thread_id, int num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void fill(int thread_id, int num_threads, FUNC_DO func);
void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
bool available();
void markAvailable();
void markUnavailable();
void finish();
void execute();
};
}
#endif //DEV_TESTS_CALLABLEINTERFACE_H
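A minimal sketch of the handshake a ThreadPool worker performs with this class (assumed usage, not part of the commit; in the shipped code the producer side is Ticket::enqueue and the consumer side is executionLoopWithInterface_ in ThreadPool.cpp):

#include <execution/CallableInterface.h>
#include <thread>
#include <cstdio>

int main() {
    samediff::CallableInterface iface;

    // worker side: block until a task is filled in, run it, then signal completion
    std::thread worker([&iface] {
        iface.waitForTask();
        iface.execute();
    });

    // producer side: fill() stores a FUNC_1D (thread_id, start, stop, increment) plus its span
    iface.fill(0, 1, [](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) {
        for (auto e = start; e < stop; e += increment)
            printf("thread %i -> %i\n", (int) thread_id, (int) e);
    }, 0, 10, 1);

    // block until execute() has marked the task as finished
    iface.waitForCompletion();
    worker.join();
    return 0;
}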


@ -0,0 +1,92 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_CALLABLEWITHARGUMENTS_H
#define DEV_TESTS_CALLABLEWITHARGUMENTS_H
#include <functional>
#include <vector>
#include <atomic>
#include <condition_variable>
#include <op_boilerplate.h>
namespace samediff {
class CallableWithArguments {
FUNC_DO _function_do;
FUNC_1D _function_1d;
FUNC_2D _function_2d;
FUNC_3D _function_3d;
std::vector<int64_t> _arguments;
std::atomic<bool> _finished;
std::condition_variable _condition;
std::mutex _lock;
int _dimensions = 0;
uint64_t _threadId;
uint64_t _numThreads;
public:
CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads);
CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x);
CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y);
CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z);
/**
* This method returns number of dimensions
* @return
*/
int dimensions();
/**
* This method checks if this callable is finished
* @return
*/
bool finished();
/**
* this method marks this Callable as finished
*/
void finish();
/**
* This method blocks until callable is finished
*/
void waitUntilFinished();
std::vector<int64_t>& arguments();
FUNC_DO function_do();
FUNC_1D function_1d();
FUNC_2D function_2d();
FUNC_3D function_3d();
uint64_t threadId();
uint64_t numThreads();
};
}
#endif //DEV_TESTS_CALLABLEWITHARGUMENTS_H


@ -0,0 +1,71 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_THREADPOOL_H
#define SAMEDIFF_THREADPOOL_H
#include <list>
#include <vector>
#include <thread>
#include <atomic>
#include <mutex>
#include <execution/BlockingQueue.h>
#include <execution/CallableWithArguments.h>
#include <execution/CallableInterface.h>
#include <execution/Ticket.h>
#include <queue>
namespace samediff {
class ThreadPool {
private:
static ThreadPool* _INSTANCE;
std::vector<std::thread*> _threads;
std::vector<BlockingQueue<CallableWithArguments*>*> _queues;
std::vector<CallableInterface*> _interfaces;
std::mutex _lock;
std::atomic<int> _available;
std::queue<Ticket*> _tickets;
protected:
ThreadPool();
~ThreadPool();
public:
static ThreadPool* getInstance();
/**
* This method returns a Ticket covering num_threads threads ONLY if that many threads were available upon request, and nullptr otherwise
* @param num_threads
* @return
*/
Ticket* tryAcquire(int num_threads);
/**
* This method marks the specified number of threads as released and available for use
* @param num_threads
*/
void release(int num_threads = 1);
void release(Ticket *ticket);
};
}
#endif //DEV_TESTS_THREADPOOL_H
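For reference, a sketch of direct ThreadPool/Ticket use, mirroring what Threads::parallel_tad does internally (pooledFill, buffer and length are hypothetical names; tryAcquire returning nullptr means the caller must fall back to the current thread):

#include <execution/ThreadPool.h>
#include <execution/Ticket.h>

static void pooledFill(float *buffer, int64_t length, int numThreads) {
    auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);
    if (ticket == nullptr) {
        // no workers available - run on the calling thread
        for (int64_t e = 0; e < length; e++)
            buffer[e] = 1.0f;
        return;
    }

    auto span = length / numThreads;
    for (int e = 0; e < numThreads; e++) {
        auto start_ = span * e;
        auto stop_ = (e == numThreads - 1) ? length : start_ + span;

        // one FUNC_1D per acquired worker; the ticket routes it to an attached CallableInterface
        ticket->enqueue(e, numThreads, [buffer](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) {
            for (auto i = start; i < stop; i += increment)
                buffer[i] = 1.0f;
        }, start_, stop_, 1);
    }

    // block until every worker reports completion, then return the ticket to the pool
    ticket->waitAndRelease();
}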


@ -0,0 +1,160 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_THREADS_H
#define SAMEDIFF_THREADS_H
#include <functional>
#include <openmp_pragmas.h>
#include <op_boilerplate.h>
#include <Environment.h>
#include <op_enums.h>
namespace samediff {
class ThreadsHelper {
public:
static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y);
static int numberOfThreads3d(int maxThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z);
static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y);
static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z);
};
class Span {
private:
int64_t _startX, _stopX, _incX;
public:
Span(int64_t start_x, int64_t stop_x, int64_t inc_x);
~Span() = default;
int64_t startX() const;
int64_t stopX() const;
int64_t incX() const;
static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x);
};
class Span2 {
private:
int64_t _startX, _stopX, _incX;
int64_t _startY, _stopY, _incY;
public:
Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
~Span2() = default;
int64_t startX() const;
int64_t startY() const;
int64_t stopX() const;
int64_t stopY() const;
int64_t incX() const;
int64_t incY() const;
static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
};
class Span3 {
private:
int64_t _startX, _stopX, _incX;
int64_t _startY, _stopY, _incY;
int64_t _startZ, _stopZ, _incZ;
public:
Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
~Span3() = default;
int64_t startX() const;
int64_t startY() const;
int64_t startZ() const;
int64_t stopX() const;
int64_t stopY() const;
int64_t stopZ() const;
int64_t incX() const;
int64_t incY() const;
int64_t incZ() const;
static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
};
class Threads {
public:
/**
* This function executes 1 dimensional loop for a given number of threads
* PLEASE NOTE: this function can use smaller number of threads than requested.
*
* @param function
* @param numThreads
* @param start
* @param stop
* @param increment
* @return
*/
static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads());
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads());
/**
*
* @param function
* @param numThreads
* @param start_x
* @param stop_x
* @param inc_x
* @param start_y
* @param stop_y
* @param inc_y
* @return
*/
static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads(), bool debug = false);
/**
*
* @param function
* @param numThreads
* @param start_x
* @param stop_x
* @param inc_x
* @param start_y
* @param stop_y
* @param inc_y
* @param start_z
* @param stop_z
* @param inc_z
* @return
*/
static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
/**
*
* @param function
* @param numThreads
* @return
*/
static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
};
}
#endif //SAMEDIFF_THREADS_H
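A usage sketch for the 2D overload (an assumption on my part: FUNC_2D is taken to be a std::function-style alias receiving thread_id followed by start/stop/increment for each loop, which is the shape parallel_for passes through to it; addBias and its arguments are hypothetical):

#include <execution/Threads.h>

void addBias(float *out, const float *bias, int64_t rows, int64_t cols) {
    // the helper decides how many threads to use and which of the two loops to split
    samediff::Threads::parallel_for([out, bias, cols](uint64_t thread_id,
                                                      int64_t start_x, int64_t stop_x, int64_t inc_x,
                                                      int64_t start_y, int64_t stop_y, int64_t inc_y) {
        for (auto i = start_x; i < stop_x; i += inc_x)
            for (auto j = start_y; j < stop_y; j += inc_y)
                out[i * cols + j] += bias[j];
    }, 0, rows, 1, 0, cols, 1);
}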


@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_TICKET_H
#define SAMEDIFF_TICKET_H
#include <vector>
#include <execution/BlockingQueue.h>
#include <execution/CallableWithArguments.h>
#include <execution/CallableInterface.h>
#include <atomic>
#include <mutex>
namespace samediff {
class Ticket {
private:
bool _acquired = false;
std::vector<BlockingQueue<CallableWithArguments*>*> _queues;
std::vector<CallableWithArguments*> _callables;
std::vector<CallableInterface*> _interfaces;
uint32_t _acquiredThreads = 0;
public:
explicit Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues);
Ticket();
~Ticket() = default;
bool acquired();
void acquiredThreads(uint32_t threads);
void attach(uint32_t thread_id, CallableInterface *interface);
// deprecated one
void enqueue(int thread_id, CallableWithArguments* callable);
void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func);
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x);
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, int64_t inc_z);
void waitAndRelease();
};
}
#endif //DEV_TESTS_TICKET_H


@ -0,0 +1,73 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/BlockingQueue.h>
#include <CallableWithArguments.h>
#include <thread>
namespace samediff {
template <typename T>
BlockingQueue<T>::BlockingQueue(int queueSize) {
_size = 0;
_available = true;
}
template <typename T>
T BlockingQueue<T>::poll() {
// locking until there's something in the queue
std::unique_lock<std::mutex> lock(_lock);
_condition.wait(lock, [&]{ return this->_size.load() != 0; });
T t(std::move(_queue.front()));
_queue.pop();
_size--;
return t;
}
template <typename T>
void BlockingQueue<T>::put(const T &t) {
{
// locking before push, unlocking after
std::unique_lock<std::mutex> lock(_lock);
_queue.push(t);
_size++;
}
// notifying condition
_condition.notify_one();
}
template <typename T>
bool BlockingQueue<T>::available() {
return _available.load();
}
template <typename T>
void BlockingQueue<T>::markAvailable() {
_available = true;
}
template <typename T>
void BlockingQueue<T>::markUnavailable() {
_available = false;
}
template class BlockingQueue<CallableWithArguments*>;
}


@ -0,0 +1,213 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/CallableInterface.h>
#include <helpers/logger.h>
namespace samediff {
CallableInterface::CallableInterface() {
// initial state is available
_available = true;
_filled = false;
_finished = false;
}
bool CallableInterface::available() {
return _available.load();
}
void CallableInterface::markUnavailable() {
_available = false;
}
void CallableInterface::markAvailable() {
_available = true;
}
void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) {
_function_do = std::move(func);
_branch = 0;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, int64_t startX, int64_t stopX, int64_t incX) {
_function_1d = std::move(func);
_arguments[0] = startX;
_arguments[1] = stopX;
_arguments[2] = incX;
_branch = 1;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y) {
_function_2d = std::move(func);
_arguments[0] = startX;
_arguments[1] = stopX;
_arguments[2] = incX;
_arguments[3] = start_y;
_arguments[4] = stop_y;
_arguments[5] = inc_y;
_branch = 2;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) {
_function_3d = std::move(func);
_arguments[0] = startX;
_arguments[1] = stopX;
_arguments[2] = incX;
_arguments[3] = start_y;
_arguments[4] = stop_y;
_arguments[5] = inc_y;
_arguments[6] = start_z;
_arguments[7] = stop_z;
_arguments[8] = inc_z;
_branch = 3;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, FUNC_RL func, int64_t startX, int64_t stopX, int64_t incX) {
_function_rl = std::move(func);
_arguments[0] = startX;
_arguments[1] = stopX;
_arguments[2] = incX;
_lptr = lptr;
_branch = 4;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::fill(int threadID, int numThreads, double *dptr, FUNC_RD func, int64_t startX, int64_t stopX, int64_t incX) {
_function_rd = std::move(func);
_arguments[0] = startX;
_arguments[1] = stopX;
_arguments[2] = incX;
_dptr = dptr;
_branch = 5;
_num_threads = numThreads;
_thread_id = threadID;
_finished = false;
{
std::unique_lock<std::mutex> l(_ms);
_filled = true;
}
_starter.notify_one();
}
void CallableInterface::waitForTask() {
// block until task is available
std::unique_lock<std::mutex> lock(_ms);
_starter.wait(lock, [&]{ return _filled.load(); });
}
void CallableInterface::waitForCompletion() {
//while (!_finished.load());
// block until finished
std::unique_lock<std::mutex> lock(_mf);
_finisher.wait(lock, [&] { return _finished.load(); });
}
void CallableInterface::finish() {
// mark as finished
{
std::unique_lock<std::mutex> l(_mf);
_finished.store(true);
}
_finisher.notify_one();
}
void CallableInterface::execute() {
// mark it as consumed
_filled = false;
// actually executing op
switch (_branch) {
case 0:
_function_do(_thread_id, _num_threads);
break;
case 1:
_function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
break;
case 2:
_function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5]);
break;
case 3:
_function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5], _arguments[6], _arguments[7], _arguments[8]);
break;
case 4:
_lptr[0] = _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
break;
case 5:
_dptr[0] = _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
break;
}
// notify that thread finished the job
this->finish();
}
}


@ -0,0 +1,103 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/CallableWithArguments.h>
namespace samediff {
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
_function_do = func;
_finished = false;
_threadId = thread_id;
_numThreads = numThreads;
_dimensions = 0;
}
CallableWithArguments::CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z) {
_function_3d = func;
_arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y, start_z, stop_z, increment_z};
_finished = false;
_threadId = thread_id;
_dimensions = 3;
}
CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x) {
_function_1d = func;
_arguments = {start_x, stop_x, increment_x};
_finished = false;
_threadId = thread_id;
_dimensions = 1;
}
CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y) {
_function_2d = func;
_arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y};
_finished = false;
_threadId = thread_id;
_dimensions = 2;
}
int CallableWithArguments::dimensions() {
return _dimensions;
}
std::vector<int64_t>& CallableWithArguments::arguments() {
return _arguments;
}
bool CallableWithArguments::finished() {
return _finished.load();
}
void CallableWithArguments::finish() {
std::lock_guard<std::mutex> lock(_lock);
_finished = true;
_condition.notify_one();
}
void CallableWithArguments::waitUntilFinished() {
std::unique_lock<std::mutex> lock(_lock);
_condition.wait(lock, [&]{ return _finished.load(); });
}
FUNC_1D CallableWithArguments::function_1d() {
return _function_1d;
}
FUNC_2D CallableWithArguments::function_2d() {
return _function_2d;
}
FUNC_DO CallableWithArguments::function_do() {
return _function_do;
}
FUNC_3D CallableWithArguments::function_3d() {
return _function_3d;
}
uint64_t CallableWithArguments::threadId() {
return _threadId;
}
uint64_t CallableWithArguments::numThreads() {
return _numThreads;
}
}


@ -0,0 +1,194 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/ThreadPool.h>
#include <stdexcept>
#include <helpers/logger.h>
#if defined(_WIN32) || defined(_WIN64)
//#include <windows.h>
#endif
namespace samediff {
// this function is executed once per thread; it polls functions from the queue and executes them via the wrapper
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
while (true) {
// this method blocks until there's something within queue
auto c = queue->poll();
//nd4j_printf("ThreadPool: starting thread %i\n", c->threadId());
switch (c->dimensions()) {
case 0: {
c->function_do()(c->threadId(), c->numThreads());
c->finish();
}
break;
case 1: {
auto args = c->arguments();
c->function_1d()(c->threadId(), args[0], args[1], args[2]);
c->finish();
}
break;
case 2: {
auto args = c->arguments();
c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5]);
c->finish();
//nd4j_printf("ThreadPool: finished thread %i\n", c->threadId());
}
break;
case 3: {
auto args = c->arguments();
c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]);
c->finish();
}
break;
default:
throw std::runtime_error("Don't know what to do with provided Callable");
}
}
}
static void executionLoopWithInterface_(int thread_id, CallableInterface *c) {
while (true) {
// blocking here until there's something to do
c->waitForTask();
// execute whatever we have
c->execute();
}
}
ThreadPool::ThreadPool() {
// TODO: number of threads must reflect number of cores for UMA system. In case of NUMA it should be per-device pool
// FIXME: on mobile phones this feature must NOT be used
_available = nd4j::Environment::getInstance()->maxThreads();
_queues.resize(_available.load());
_threads.resize(_available.load());
_interfaces.resize(_available.load());
// creating threads here
for (int e = 0; e < _available.load(); e++) {
_queues[e] = new BlockingQueue<CallableWithArguments*>(2);
_interfaces[e] = new CallableInterface();
_threads[e] = new std::thread(executionLoopWithInterface_, e, _interfaces[e]);
_tickets.push(new Ticket());
// _threads[e] = new std::thread(executionLoop_, e, _queues[e]);
// TODO: add other platforms here as well
// now we must set affinity, and it's going to be platform-specific thing
#ifdef LINUX_BUILD
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(e, &cpuset);
int rc = pthread_setaffinity_np(_threads[e]->native_handle(), sizeof(cpu_set_t), &cpuset);
if (rc != 0)
throw std::runtime_error("Failed to set pthread affinity");
#endif
/*
#if defined(_WIN32) || defined(_WIN64)
// we can't set affinity to more than 64 cores
if (e <= 64) {
auto mask = (static_cast<DWORD_PTR>(1) << e);
auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask);
if (!result)
throw std::runtime_error("Failed to set pthread affinity");
}
// that's fine. no need for time_critical here
SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST);
#endif
*/
}
}
ThreadPool::~ThreadPool() {
// TODO: implement this one properly
for (int e = 0; e < _queues.size(); e++) {
// stop each and every thread
// release queue and thread
//delete _queues[e];
//delete _threads[e];
}
}
static std::mutex _lmutex;
ThreadPool* ThreadPool::getInstance() {
std::unique_lock<std::mutex> lock(_lmutex);
if (!_INSTANCE)
_INSTANCE = new ThreadPool();
return _INSTANCE;
}
void ThreadPool::release(int numThreads) {
_available += numThreads;
}
Ticket* ThreadPool::tryAcquire(int numThreads) {
//std::vector<BlockingQueue<CallableWithArguments*>*> queues;
Ticket *t = nullptr;
// we check for threads availability first
bool threaded = false;
{
// we lock before checking availability
std::unique_lock<std::mutex> lock(_lock);
if (_available >= numThreads) {
threaded = true;
_available -= numThreads;
// getting a ticket from the queue
t = _tickets.front();
_tickets.pop();
// ticket must contain information about number of threads for the current session
t->acquiredThreads(numThreads);
// filling ticket with executable interfaces
for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) {
if (_interfaces[e]->available()) {
t->attach(i++, _interfaces[e]);
_interfaces[e]->markUnavailable();
}
}
}
}
// we either dispatch tasks to threads, or run single-threaded
if (threaded) {
return t;
} else {
// if there's no threads available - return nullptr
return nullptr;
}
}
void ThreadPool::release(samediff::Ticket *ticket) {
// returning ticket back to the queue
std::unique_lock<std::mutex> lock(_lock);
_tickets.push(ticket);
}
ThreadPool* ThreadPool::_INSTANCE = 0;
}


@ -0,0 +1,641 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <vector>
#include <thread>
#include <helpers/logger.h>
#include <templatemath.h>
#include <shape.h>
namespace samediff {
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
// let's see how many threads we actually need first
auto optimalThreads = nd4j::math::nd4j_max<uint64_t>(1, numberOfElements / 1024);
// now return the smallest value
return nd4j::math::nd4j_min<int>(optimalThreads, maxThreads);
}
Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
_startX = startX;
_startY = startY;
_startZ = startZ;
_stopX = stopX;
_stopY = stopY;
_stopZ = stopZ;
_incX = incX;
_incY = incY;
_incZ = incZ;
}
Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
switch (loop) {
case 1: {
auto span = (stopX - startX) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopX;
return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ);
}
break;
case 2: {
auto span = (stopY - startY) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopY;
return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ);
}
break;
case 3: {
auto span = (stopZ - startZ) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopZ;
return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ);
}
break;
default:
throw std::runtime_error("");
}
return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
}
Span::Span(int64_t startX, int64_t stopX, int64_t incX) {
_startX = startX;
_stopX = stopX;
_incX = incX;
}
Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX) {
auto span = (stopX - startX) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopX;
return Span(s, e, incX);
}
Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) {
_startX = startX;
_startY = startY;
_stopX = stopX;
_stopY = stopY;
_incX = incX;
_incY = incY;
}
Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) {
switch (loop) {
case 1: {
auto span = (stopX - startX) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopX;
return Span2(s, e, incX, startY, stopY, incY);
}
break;
case 2: {
auto span = (stopY - startY) / numThreads;
auto s = span * threadID;
auto e = s + span;
if (threadID == numThreads - 1)
e = stopY;
return Span2(startX, stopX, incX, s, e, incY);
}
break;
default:
throw std::runtime_error("");
}
}
int64_t Span::startX() const {
return _startX;
}
int64_t Span::stopX() const {
return _stopX;
}
int64_t Span::incX() const {
return _incX;
}
int64_t Span2::startX() const {
return _startX;
}
int64_t Span2::startY() const {
return _startY;
}
int64_t Span2::stopX() const {
return _stopX;
}
int64_t Span2::stopY() const {
return _stopY;
}
int64_t Span2::incX() const {
return _incX;
}
int64_t Span2::incY() const {
return _incY;
}
int64_t Span3::startX() const {
return _startX;
}
int64_t Span3::startY() const {
return _startY;
}
int64_t Span3::startZ() const {
return _startZ;
}
int64_t Span3::stopX() const {
return _stopX;
}
int64_t Span3::stopY() const {
return _stopY;
}
int64_t Span3::stopZ() const {
return _stopZ;
}
int64_t Span3::incX() const {
return _incX;
}
int64_t Span3::incY() const {
return _incY;
}
int64_t Span3::incZ() const {
return _incZ;
}
int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, uint64_t itersY) {
// if one of dimensions is definitely too small - we just pick the other one
if (itersX < numThreads && itersY >= numThreads)
return 2;
if (itersY < numThreads && itersX >= numThreads)
return 1;
// next step - we pick the most balanced dimension
auto remX = itersX % numThreads;
auto remY = itersY % numThreads;
auto splitY = itersY / numThreads;
// if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
if (remX == 0)
return 1;
if (remY == 0)
return 2;
// if there's no loop without a remainder - we're picking one with smaller remainder
if (remX < remY)
return 1;
if (remY < remX && splitY >= 64) // we don't want too small splits over last dimension, or vectorization will fail
return 2;
// if loops are equally sized - give the preference to the first loop
return 1;
}
static int threads_(int maxThreads, uint64_t elements) {
if (elements == maxThreads) {
return maxThreads;
}
else if (elements > maxThreads) {
// if we have full load across thread, or at least half of threads can be utilized
auto rem = elements % maxThreads;
if (rem == 0 || rem >= maxThreads / 3)
return maxThreads;
else
return threads_(maxThreads - 1, elements);
}
else if (elements < maxThreads) {
return elements;
}
return 1;
}
int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y) {
// in some cases there's nothing to think about, part 1
if (iters_x < maxThreads && iters_y < maxThreads)
return nd4j::math::nd4j_max<int>(iters_x, iters_y);
auto remX = iters_x % maxThreads;
auto remY = iters_y % maxThreads;
// in some cases there's nothing to think about, part 2
if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0))
return maxThreads;
// at this point we assume that no loop perfectly matches the number of our threads
// so let's pick something as equal as possible
if (iters_x > maxThreads || iters_y > maxThreads)
return maxThreads;
else
return numberOfThreads2d(maxThreads - 1, iters_x, iters_y);
}
int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) {
// we don't want to run underloaded threads
if (itersX * itersY * itersZ <= 32)
return 1;
auto remX = itersX % maxThreads;
auto remY = itersY % maxThreads;
auto remZ = itersZ % maxThreads;
// if we have perfect balance across one of dimensions - just go for it
if ((itersX >= maxThreads && remX == 0) || (itersY >= maxThreads && remY == 0) || (itersZ >= maxThreads && remZ == 0))
return maxThreads;
int threadsX = 0, threadsY = 0, threadsZ = 0;
// now we look into the possible number of threads along each dimension
threadsX = threads_(maxThreads, itersX);
threadsY = threads_(maxThreads, itersY);
threadsZ = threads_(maxThreads, itersZ);
// we want to split as close to outer loop as possible, so checking it out first
if (threadsX >= threadsY && threadsX >= threadsZ)
return threadsX;
else if (threadsY >= threadsX && threadsY >= threadsZ)
return threadsY;
else if (threadsZ >= threadsX && threadsZ >= threadsY)
return threadsZ;
return 1;
}
int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) {
auto remX = itersX % numThreads;
auto remY = itersY % numThreads;
auto remZ = itersZ % numThreads;
auto splitX = itersX / numThreads;
auto splitY = itersY / numThreads;
auto splitZ = itersZ / numThreads;
// if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
if (remX == 0)
return 1;
else if (remY == 0)
return 2;
else if (remZ == 0) // TODO: we don't want too small splits over last dimension? or we do?
return 3;
if (itersX > numThreads)
return 1;
else if (itersY > numThreads)
return 2;
else if (itersZ > numThreads)
return 3;
return 1;
}
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop");
auto delta = (stop - start);
if (numThreads > delta)
numThreads = delta;
if (numThreads == 0)
return 0;
// shortcut
if (numThreads == 1) {
function(0, start, stop, increment);
return 1;
}
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
// if we got our threads - we'll run our jobs here
auto span = delta / numThreads;
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = start_ + span;
// last thread will process tail
if (e == numThreads - 1)
stop_ = stop;
// putting the task into the queue for a given thread
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
}
// block and wait till all threads finished the job
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads;
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, start, stop, increment);
// we tell that parallelism request declined
return 1;
}
}
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_for got start > stop");
auto delta = (stop - start);
// in some cases we just fire func as is
if (delta == 0 || numThreads == 1) {
function(0, start, stop, increment);
return 1;
}
auto numElements = delta / increment;
// we decide what's optimal number of threads we need here, and execute it in parallel_tad.
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
return parallel_tad(function, start, stop, increment, numThreads);
}
int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, uint64_t numThreads, bool debug) {
if (startX > stopX)
throw std::runtime_error("Threads::parallel_for got startX > stopX");
if (startY > stopY)
throw std::runtime_error("Threads::parallel_for got startY > stopY");
// number of elements per loop
auto delta_x = (stopX - startX);
auto delta_y = (stopY - startY);
// number of iterations per loop
auto itersX = delta_x / incX;
auto itersY = delta_y / incY;
// total number of iterations
auto iters_t = itersX * itersY;
// we check for the case where fewer threads than requested are actually needed for this workload
numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY);
// basic shortcut for no-threading cases
if (numThreads == 1) {
function(0, startX, stopX, incX, startY, stopY, incY);
return 1;
}
// We have a couple of scenarios:
// either we split the workload along the 1st loop, or along the 2nd
auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY);
// for debug mode we execute things inplace, without any threads
if (debug) {
for (int e = 0; e < numThreads; e++) {
auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, startY, stopY, incY);
function(e, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
}
// but we still mimic multithreaded execution
return numThreads;
} else {
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
for (int e = 0; e < numThreads; e++) {
auto threadId = numThreads - e - 1;
auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
}
// block until all threads finish their job
ticket->waitAndRelease();
return numThreads;
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY);
// we tell that parallelism request declined
return 1;
}
};
}
int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ, uint64_t numThreads) {
if (startX > stopX)
throw std::runtime_error("Threads::parallel_for got startX > stopX");
if (startY > stopY)
throw std::runtime_error("Threads::parallel_for got startY > stopY");
if (startZ > stopZ)
throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
auto delta_x = stopX - startX;
auto delta_y = stopY - startY;
auto delta_z = stopZ - startZ;
auto itersX = delta_x / incX;
auto itersY = delta_y / incY;
auto itersZ = delta_z / incZ;
numThreads = 1; //ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ);
if (numThreads == 1) {
// loop is too small - executing function as is
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
return 1;
}
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
if (ticket != nullptr) {
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
for (int e = 0; e < numThreads; e++) {
auto thread_id = numThreads - e - 1;
auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ());
}
// block until we're done
ticket->waitAndRelease();
// we tell that parallelism request succeeded
return numThreads;
} else {
// if there were no threads available - we'll execute function right within current thread
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
// we tell that parallelism request declined
return 1;
}
}
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket != nullptr) {
// submit tasks one by one
for (uint64_t e = 0; e < numThreads - 1; e++)
ticket->enqueue(e, numThreads, function);
function(numThreads - 1, numThreads);
ticket->waitAndRelease();
return numThreads;
} else {
// if there's no threads available - we'll execute function sequentially one by one
for (uint64_t e = 0; e < numThreads; e++)
function(e, numThreads);
return numThreads;
}
return numThreads;
}
int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_long got start > stop");
auto delta = (stop - start);
if (delta == 0 || numThreads == 1)
return function(0, start, stop, increment);
auto numElements = delta / increment;
// we decide what's optimal number of threads we need here, and execute it
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
if (numThreads == 1)
return function(0, start, stop, increment);
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr)
return function(0, start, stop, increment);
// create temporary array
int64_t intermediatery[256];
auto span = delta / numThreads;
// execute threads in parallel
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
if (e == numThreads - 1)
intermediatery[e] = function(e, start_, stop, increment);
else
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
}
ticket->waitAndRelease();
// aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++)
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
// return accumulated result
return intermediatery[0];
}
double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) {
if (start > stop)
throw std::runtime_error("Threads::parallel_double got start > stop");
auto delta = (stop - start);
if (delta == 0 || numThreads == 1)
return function(0, start, stop, increment);
auto numElements = delta / increment;
// we decide what's optimal number of threads we need here, and execute it
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
if (numThreads == 1)
return function(0, start, stop, increment);
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
if (ticket == nullptr)
return function(0, start, stop, increment);
// create temporary array
double intermediatery[256];
auto span = delta / numThreads;
// execute threads in parallel
for (uint32_t e = 0; e < numThreads; e++) {
auto start_ = span * e + start;
auto stop_ = span * (e + 1) + start;
if (e == numThreads - 1)
intermediatery[e] = function(e, start_, stop, increment);
else
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
}
ticket->waitAndRelease();
// aggregate results in single thread
for (uint64_t e = 1; e < numThreads; e++)
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
// return accumulated result
return intermediatery[0];
}
}
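A minimal sketch (an assumption, not part of this commit) of the reduction entry points defined above: each worker returns a partial result for its [start, stop) span, and the aggregator lambda folds the partials together on the calling thread. countMultiplesOfThree and its arguments are hypothetical names.

#include <execution/Threads.h>

int64_t countMultiplesOfThree(const int64_t *values, int64_t length) {
    // FUNC_RL shape: (thread_id, start, stop, increment) -> int64_t partial result
    auto partial = [values](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> int64_t {
        int64_t local = 0;
        for (auto e = start; e < stop; e += increment)
            if (values[e] % 3 == 0)
                local++;
        return local;
    };

    // aggregator merges two partial results; applied pairwise to the per-thread values
    auto aggregator = [](int64_t a, int64_t b) -> int64_t { return a + b; };

    return samediff::Threads::parallel_long(partial, aggregator, 0, length);
}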


@ -0,0 +1,94 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <execution/Ticket.h>
#include <execution/ThreadPool.h>
#include <helpers/logger.h>
#include <array>
namespace samediff {
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
_acquired = true;
_queues = queues;
}
Ticket::Ticket() {
_acquired = true;
_interfaces.resize(nd4j::Environment::getInstance()->maxThreads());
}
bool Ticket::acquired() {
return _acquired;
}
void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) {
_queues[thread_id]->put(callable);
_callables.emplace_back(callable);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) {
_interfaces[thread_id]->fill(thread_id, num_threads, func);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
_interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
_interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, stop_x, inc_x);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
_interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, stop_x, inc_x);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) {
_interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, stop_x, inc_x, start_y, stop_y, inc_y);
}
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) {
_interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x, start_y, stop_y, inc_y, start_z, stop_z, inc_z);
}
void Ticket::acquiredThreads(uint32_t threads) {
_acquiredThreads = threads;
}
void Ticket::waitAndRelease() {
for (uint32_t e = 0; e < this->_acquiredThreads; e++) {
// block until finished
_interfaces[e]->waitForCompletion();
// mark available
_interfaces[e]->markAvailable();
// increment availability counter
ThreadPool::getInstance()->release();
}
// return this ticket back to the pool
ThreadPool::getInstance()->release(this);
}
void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) {
_interfaces[thread_id] = interface;
}
}
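A rough sketch of the ticket lifecycle as driven by the Threads frontend, assuming the ThreadPool exposes a tryAcquire(numThreads) call (the acquisition API is an assumption here; the enqueue/waitAndRelease calls match the dispatch loop shown earlier in this diff):

    // hypothetical driver: split [start, stop) across numThreads workers, then block until all are done
    auto ticket = samediff::ThreadPool::getInstance()->tryAcquire(numThreads);   // assumed acquisition API
    if (ticket != nullptr) {
        auto span = (stop - start) / numThreads;
        for (uint32_t e = 0; e < numThreads; e++) {
            auto start_ = start + span * e;
            auto stop_  = (e == numThreads - 1) ? stop : start_ + span;          // last worker takes the remainder
            ticket->enqueue(e, numThreads, func, start_, stop_, increment);
        }
        ticket->waitAndRelease();   // waits on every CallableInterface, marks it available, returns the ticket to the pool
    } else {
        func(0, start, stop, increment);   // no free workers: run the whole range inline
    }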

View File

@ -232,6 +232,7 @@ namespace nd4j {
} }
static nd4j::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar); static nd4j::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar);
static void deleteOpByType(OpType opType, void *op);
}; };
} }
} }

View File

@ -19,6 +19,7 @@
// //
#include <graph/Graph.h> #include <graph/Graph.h>
#include <array/DataTypeUtils.h>
#include <helpers/EnumUtils.h> #include <helpers/EnumUtils.h>
#include <graph/FlatUtils.h> #include <graph/FlatUtils.h>
#include <NativeOps.h> #include <NativeOps.h>
@ -154,7 +155,7 @@ namespace nd4j {
Nd4jLong *newShape = nullptr; Nd4jLong *newShape = nullptr;
// if that's scalar output - we don't care about previous node // if that's scalar output - we don't care about previous node
if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == MAX_INT)) { if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == nd4j::DataTypeUtils::max<int>())) {
newShape = new Nd4jLong[8]; newShape = new Nd4jLong[8];
newShape[0] = 2; newShape[0] = 2;

View File

@ -682,8 +682,9 @@ namespace nd4j {
if (_protoContext != nullptr) if (_protoContext != nullptr)
delete _protoContext; delete _protoContext;
if (_isDeductable && _customOp != nullptr) if (_isDeductable && _customOp != nullptr) {
delete _customOp; Node::deleteOpByType(_opType, _customOp);
}
} }
int nd4j::graph::Node::getRewindNode() { int nd4j::graph::Node::getRewindNode() {
@ -710,6 +711,70 @@ namespace nd4j {
return false; return false;
} }
void nd4j::graph::Node::deleteOpByType(OpType opType, void *op) {
switch (opType) {
case OpType_PAIRWISE:
delete reinterpret_cast<nd4j::ops::LegacyPairwiseTransformOp*>(op);
break;
case OpType_PAIRWISE_BOOL:
delete reinterpret_cast<nd4j::ops::LegacyPairwiseTransformBoolOp*>(op);
break;
case OpType_TRANSFORM_STRICT:
delete reinterpret_cast<nd4j::ops::LegacyTransformStrictOp*>(op);
break;
case OpType_TRANSFORM_SAME:
delete reinterpret_cast<nd4j::ops::LegacyTransformSameOp*>(op);
break;
case OpType_TRANSFORM_FLOAT:
delete reinterpret_cast<nd4j::ops::LegacyTransformFloatOp*>(op);
break;
case OpType_TRANSFORM_BOOL:
delete reinterpret_cast<nd4j::ops::LegacyTransformBoolOp*>(op);
break;
case OpType_SCALAR:
delete reinterpret_cast<nd4j::ops::LegacyScalarOp*>(op);
break;
case OpType_SCALAR_BOOL:
delete reinterpret_cast<nd4j::ops::LegacyScalarBoolOp*>(op);
break;
case OpType_REDUCE_3:
delete reinterpret_cast<nd4j::ops::LegacyReduce3Op*>(op);
break;
case OpType_REDUCE_SAME:
delete reinterpret_cast<nd4j::ops::LegacyReduceSameOp*>(op);
break;
case OpType_REDUCE_FLOAT:
delete reinterpret_cast<nd4j::ops::LegacyReduceFloatOp*>(op);
break;
case OpType_REDUCE_LONG:
delete reinterpret_cast<nd4j::ops::LegacyReduceLongOp*>(op);
break;
case OpType_REDUCE_BOOL:
delete reinterpret_cast<nd4j::ops::LegacyReduceBoolOp*>(op);
break;
case OpType_INDEX_REDUCE:
delete reinterpret_cast<nd4j::ops::LegacyIndexReduceOp*>(op);
break;
case OpType_SUMMARYSTATS:
delete reinterpret_cast<nd4j::ops::LegacyStatsOp*>(op);
break;
case OpType_RANDOM:
delete reinterpret_cast<nd4j::ops::LegacyRandomOp*>(op);
break;
case OpType_BROADCAST:
delete reinterpret_cast<nd4j::ops::LegacyBroadcastOp*>(op);
break;
case OpType_BROADCAST_BOOL:
delete reinterpret_cast<nd4j::ops::LegacyBroadcastBoolOp*>(op);
break;
case OpType_CUSTOM:
delete reinterpret_cast<nd4j::ops::DeclarableOp*>(op);
break;
default:
throw std::runtime_error("Bad opType passed in");
}
}
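The explicit dispatch above exists because the stored op pointer is treated as untyped at destruction time; casting back to the concrete Legacy*Op class before delete ensures the correct destructor runs. A sketch of the intended pairing (hypothetical call site, not part of this diff):

    // build and later destroy through the same OpType, so deleteOpByType can cast back to the matching class
    auto *op = Node::buildOpByType(opType, numInputs, numIArgs, numTArgs, opNum, scalar);
    // ... node executes the op ...
    Node::deleteOpByType(opType, op);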
nd4j::ops::DeclarableOp* nd4j::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) { nd4j::ops::DeclarableOp* nd4j::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) {
switch (opType) { switch (opType) {
case OpType_PAIRWISE: case OpType_PAIRWISE:

File diff suppressed because it is too large

View File

@ -721,7 +721,7 @@ namespace shape {
INLINEDEF void TAD::createOffsets() { INLINEDEF void TAD::createOffsets() {
this->tadOffsets = new Nd4jLong[this->numTads]; this->tadOffsets = new Nd4jLong[this->numTads];
uint nT = this->numTads; uint nT = this->numTads;
PRAGMA_OMP_PARALLEL_FOR_SIMD
for(uint i = 0; i < nT; i++) for(uint i = 0; i < nT; i++)
this->tadOffsets[i] = this->tadOffset(i); this->tadOffsets[i] = this->tadOffset(i);
} }

View File

@ -19,7 +19,6 @@
// //
#include "../OpBenchmark.h" #include "../OpBenchmark.h"
#include <helpers/BlasHelper.h>
#include <MmulHelper.h> #include <MmulHelper.h>
#ifndef DEV_TESTS_MATRIXBENCHMARK_H #ifndef DEV_TESTS_MATRIXBENCHMARK_H

View File

@ -22,6 +22,7 @@
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <helpers/BlasHelper.h> #include <helpers/BlasHelper.h>
#include <exceptions/datatype_exception.h> #include <exceptions/datatype_exception.h>
#include <execution/Threads.h>
namespace nd4j { namespace nd4j {
@ -74,26 +75,28 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
// } // }
// } // }
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
-    for(uint row = 0; row < M; ++row) {
-        for(uint col = 0; col < N; ++col) {
-            T3* c = flagC ? (C + row + col * ldc) : (C + row * ldc + col);
-            T3 val = 0;
-            PRAGMA_OMP_SIMD
-            for(uint i = 0; i < K; ++i) {
-                T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda);
-                T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i);
-                val += alphaZ * a * b;
-            }
-            if(betaZ)
-                *c = val + betaZ * *c;
-            else
-                *c = val;
-        }
-    }
+    auto func = PRAGMA_THREADS_FOR_2D {
+        for (auto row = start_x; row < stop_x; row += inc_x) {
+            for (auto col = start_y; col < stop_y; col += inc_y) {
+                T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col);
+                T3 val = 0;
+
+                PRAGMA_OMP_SIMD
+                for (uint i = 0; i < K; ++i) {
+                    T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda);
+                    T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i);
+                    val += alphaZ * a * b;
+                }
+
+                if (betaZ)
+                    *c = val + betaZ * *c;
+                else
+                    *c = val;
+            }
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1);
 }
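As a standalone illustration of the 2-D threading primitive used in this hunk, assuming made-up buffers in and out of shape M x N (start_x/stop_x/inc_x and start_y/stop_y/inc_y are the range parameters the PRAGMA_THREADS_FOR_2D lambda receives):

    auto tileFunc = PRAGMA_THREADS_FOR_2D {
        for (auto i = start_x; i < stop_x; i += inc_x)
            for (auto j = start_y; j < stop_y; j += inc_y)
                out[i * N + j] = 2 * in[i * N + j];   // each thread is handed a block of the row/column index space
    };
    samediff::Threads::parallel_for(tileFunc, 0, M, 1, 0, N, 1);   // takes over the role of OpenMP's collapse(2)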
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -108,24 +111,27 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
const bool flagA = aOrder == 'f'; const bool flagA = aOrder == 'f';
-    PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
-    for(int row = 0; row < M; ++row) {
-        T3* y = Y + row * incy;
-        T3 val = 0;
-        PRAGMA_OMP_SIMD
-        for(int i = 0; i < N; ++i) {
-            T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i);
-            T3 x = *(X + i * incx);
-            val += alphaZ * a * x;
-        }
-        if(betaZ)
-            *y = val + betaZ * *y;
-        else
-            *y = val;
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto row = start; row < stop; row += increment) {
+            T3 *y = Y + row * incy;
+            T3 val = 0;
+
+            PRAGMA_OMP_SIMD
+            for (int i = 0; i < N; ++i) {
+                T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i);
+                T3 x = *(X + i * incx);
+                val += alphaZ * a * x;
+            }
+
+            if (betaZ)
+                *y = val + betaZ * *y;
+            else
+                *y = val;
+        }
+    };
+
+    samediff::Threads::parallel_for(func, 0, M);
 }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -141,7 +147,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
T3 sum = 0; T3 sum = 0;
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
for(int i = 0; i < length; ++i) for(int i = 0; i < length; ++i)
sum = sum + X[i * incx] * Y[i * incy]; sum += X[i * incx] * Y[i * incy];
*Z = alphaZ * sum + betaZ * *Z; *Z = alphaZ * sum + betaZ * *Z;
} }

View File

@ -19,6 +19,7 @@
// //
#include <TrueBroadcastHelper.h> #include <TrueBroadcastHelper.h>
#include <ops/ops.h>
using namespace simdOps; using namespace simdOps;

View File

@ -44,62 +44,67 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const Nd4jLong* tadShape = shape::shapeOf(const_cast<Nd4jLong*>(tadShapeInfo)); const Nd4jLong* tadShape = shape::shapeOf(const_cast<Nd4jLong*>(tadShapeInfo));
const Nd4jLong* tadStride = shape::stride(const_cast<Nd4jLong*>(tadShapeInfo)); const Nd4jLong* tadStride = shape::stride(const_cast<Nd4jLong*>(tadShapeInfo));
int tadsPerThread = zLen / TAD_THRESHOLD;
int numThreads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
numThreads = nd4j::math::nd4j_min<int>(numThreads, omp_get_max_threads());
switch (kindOfLoop) { switch (kindOfLoop) {
//*********************************************// //*********************************************//
case nd4j::LoopKind::EWS1: { case nd4j::LoopKind::EWS1: {
-            PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
-            for (uint i = 0; i < zLen; i++) {
-                auto tad = const_cast<X*>(x) + tadOffsets[i];
-                auto indexValue = OpType::startingIndexValue(tad);
-                for (uint j = 0; j < tadLen; j++) {
-                    functions::indexreduce::IndexValue<X> comp(tad[j], j);
-                    indexValue = OpType::update(indexValue, comp, extraParams);
-                }
-                z[i] = (Z) indexValue.index;
-            }
+            auto func = PRAGMA_THREADS_FOR {
+                for (auto i = start; i < stop; i += increment) {
+                    auto tad = const_cast<X *>(x) + tadOffsets[i];
+                    auto indexValue = OpType::startingIndexValue(tad);
+
+                    for (uint j = 0; j < tadLen; j++) {
+                        functions::indexreduce::IndexValue<X> comp(tad[j], j);
+                        indexValue = OpType::update(indexValue, comp, extraParams);
+                    }
+
+                    z[i] = (Z) indexValue.index;
+                }
+            };
+
+            samediff::Threads::parallel_tad(func, 0, zLen);
 }
break; break;
//*********************************************// //*********************************************//
case nd4j::LoopKind::EWSNONZERO: { case nd4j::LoopKind::EWSNONZERO: {
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; i++) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint j = 0; j < tadLen; j++) { for (uint j = 0; j < tadLen; j++) {
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j); functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i * zEws] = (Z) indexValue.index;
} }
};
z[i * zEws] = (Z) indexValue.index; samediff::Threads::parallel_tad(func, 0, zLen);
}
} }
break; break;
//*********************************************// //*********************************************//
case nd4j::LoopKind::RANK1: { case nd4j::LoopKind::RANK1: {
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; i++) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint i0 = 0; i0 < tadLen; ++i0) { for (uint i0 = 0; i0 < tadLen; ++i0) {
functions::indexreduce::IndexValue<X> comp(tad[i0 * tadStride[0]], i0); functions::indexreduce::IndexValue<X> comp(tad[i0 * tadStride[0]], i0);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i] = (Z) indexValue.index;
} }
};
z[i] = (Z) indexValue.index; samediff::Threads::parallel_tad(func, 0, zLen);
}
} }
break; break;
@ -108,22 +113,25 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Nd4jLong newStride[2]; Nd4jLong newStride[2];
shape::updateStrides(2, tadShape, newStride, 'c'); shape::updateStrides(2, tadShape, newStride, 'c');
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; ++i) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint i0 = 0; i0 < tadShape[0]; ++i0) { for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
for (uint i1 = 0; i1 < tadShape[1]; ++i1) { for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1]; const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1];
const auto tadIndex = i0 * newStride[0] + i1; const auto tadIndex = i0 * newStride[0] + i1;
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
} }
}
z[i] = (Z) indexValue.index; z[i] = (Z) indexValue.index;
} }
};
samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -132,24 +140,27 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Nd4jLong newStride[3]; Nd4jLong newStride[3];
shape::updateStrides(3, tadShape, newStride, 'c'); shape::updateStrides(3, tadShape, newStride, 'c');
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; ++i) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint i0 = 0; i0 < tadShape[0]; ++i0) { for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
for (uint i1 = 0; i1 < tadShape[1]; ++i1) { for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
for (uint i2 = 0; i2 < tadShape[2]; ++i2) { for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]; const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2];
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2; const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2;
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
} }
} }
}
z[i] = (Z) indexValue.index; z[i] = (Z) indexValue.index;
} }
};
samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -158,26 +169,29 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Nd4jLong newStride[4]; Nd4jLong newStride[4];
shape::updateStrides(4, tadShape, newStride, 'c'); shape::updateStrides(4, tadShape, newStride, 'c');
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; ++i) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint i0 = 0; i0 < tadShape[0]; ++i0) { for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
for (uint i1 = 0; i1 < tadShape[1]; ++i1) { for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
for (uint i2 = 0; i2 < tadShape[2]; ++i2) { for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
for (uint i3 = 0; i3 < tadShape[3]; ++i3) { for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]; const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3];
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3; const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3;
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
} }
} }
} }
}
z[i] = (Z) indexValue.index; z[i] = (Z) indexValue.index;
} }
};
samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -186,28 +200,31 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
Nd4jLong newStride[5]; Nd4jLong newStride[5];
shape::updateStrides(5, tadShape, newStride, 'c'); shape::updateStrides(5, tadShape, newStride, 'c');
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; ++i) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint i0 = 0; i0 < tadShape[0]; ++i0) { for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
for (uint i1 = 0; i1 < tadShape[1]; ++i1) { for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
for (uint i2 = 0; i2 < tadShape[2]; ++i2) { for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
for (uint i3 = 0; i3 < tadShape[3]; ++i3) { for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
for (uint i4 = 0; i4 < tadShape[4]; ++i4) { for (uint i4 = 0; i4 < tadShape[4]; ++i4) {
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]; const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4];
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4; const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4;
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
} }
} }
} }
} }
}
z[i] = (Z) indexValue.index; z[i] = (Z) indexValue.index;
} }
};
samediff::Threads::parallel_tad(func, 0, zLen);
} }
break; break;
@ -216,19 +233,22 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
uint castZShapeInfo[MAX_RANK]; uint castZShapeInfo[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; i++) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint j = 0; j < tadLen; j++) { for (uint j = 0; j < tadLen; j++) {
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j); functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
z[zOffset] = (Z) indexValue.index;
} }
};
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); samediff::Threads::parallel_tad(func, 0, zLen);
z[zOffset] = (Z) indexValue.index;
}
} }
break; break;
@ -237,19 +257,22 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
uint castTadShapeInfo[MAX_RANK]; uint castTadShapeInfo[MAX_RANK];
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo); const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; i++) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint j = 0; j < tadLen; j++) { for (uint j = 0; j < tadLen; j++) {
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
z[i * zEws] = (Z) indexValue.index;
} }
};
z[i * zEws] = (Z) indexValue.index; samediff::Threads::parallel_tad(func, 0, zLen);
}
} }
break; break;
@ -260,20 +283,23 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo); const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads) auto func = PRAGMA_THREADS_FOR {
for (uint i = 0; i < zLen; i++) { for (auto i = start; i < stop; i += increment) {
auto tad = const_cast<X*>(x) + tadOffsets[i]; auto tad = const_cast<X *>(x) + tadOffsets[i];
auto indexValue = OpType::startingIndexValue(tad); auto indexValue = OpType::startingIndexValue(tad);
for (uint j = 0; j < tadLen; j++) { for (uint j = 0; j < tadLen; j++) {
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad); auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j); functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
indexValue = OpType::update(indexValue, comp, extraParams); indexValue = OpType::update(indexValue, comp, extraParams);
}
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
z[zOffset] = (Z) indexValue.index;
} }
};
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ); samediff::Threads::parallel_tad(func, 0, zLen);
z[zOffset] = (Z) indexValue.index;
}
} }
} }
} }
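Every case in this switch follows the same mechanical conversion: the OpenMP pragma over TADs goes away, the loop body becomes a PRAGMA_THREADS_FOR lambda over its [start, stop) slice, and the lambda is dispatched with parallel_tad. In isolation the pattern looks roughly like this (reduceSingleTad is a hypothetical per-TAD kernel, not part of this commit):

    auto func = PRAGMA_THREADS_FOR {
        for (auto i = start; i < stop; i += increment)
            z[i] = reduceSingleTad(x + tadOffsets[i]);   // one output element per TAD
    };
    samediff::Threads::parallel_tad(func, 0, zLen);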

View File

@ -28,24 +28,32 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
} }
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) { void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_0); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_0);

View File

@ -28,24 +28,32 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
} }
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) { void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_1); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_1);

View File

@ -28,24 +28,32 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
} }
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) { void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_2); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_2);

View File

@ -28,24 +28,32 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
#endif
} }
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) { void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams); #ifndef INLINE_LOOPS
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) { void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) { void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_3); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT Reduction3Loops, , LIBND4J_TYPES, FLOAT_TYPES_3);

View File

@ -19,3 +19,4 @@
// //
#include <helpers/Loops.h> #include <helpers/Loops.h>
#include <op_boilerplate.h>

View File

@ -26,16 +26,20 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) { void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong *tadOffsets,
X *extraParams) { X *extraParams, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), REDUCE_BOOL_OPS); #ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionBoolLoops, , LIBND4J_TYPES, BOOL_TYPES);

View File

@ -28,16 +28,19 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) { void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams) { Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_0);
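These per-type translation units now take an explicit start/stop pair and no longer parallelise internally; chunking is decided by the caller. A sketch of the calling convention this implies (the surrounding reduce kernel, its buffers, and numTads are assumed for illustration):

    auto func = PRAGMA_THREADS_FOR {
        ReductionFloatLoops<X, Z>::wrapper(opNum, x, xShapeInfo, z, zShapeInfo,
                                           tadShapeInfo, tadOffsets, extraParams,
                                           start, stop);   // each thread reduces its own slice of the TADs
    };
    samediff::Threads::parallel_tad(func, 0, numTads);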

View File

@ -28,16 +28,19 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) { void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams) { Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_1);

View File

@ -28,16 +28,19 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) { void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams) { Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_2);

View File

@ -28,16 +28,19 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) { void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Y *extraParams) { Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionFloatLoops, , LIBND4J_TYPES, FLOAT_TYPES_3);

View File

@ -33,16 +33,19 @@ namespace nd4j {
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpType> template <typename OpType>
void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) { void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X, typename Y> template<typename X, typename Y>
void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z, void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, X *extraParams) { Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_LONG_OPS); DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
#endif
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReductionLongLoops, , LIBND4J_TYPES, LONG_TYPES);

View File

@ -26,20 +26,24 @@ namespace nd4j {
template<typename X> template<typename X>
template <typename OpType> template <typename OpType>
void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) { void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams); #ifndef INLINE_LOOPS
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
#endif
} }
template<typename X> template<typename X>
void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz, void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz,
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets, Nd4jLong *tadOffsets,
X *vextraParams) { X *vextraParams, int64_t start, int64_t stop) {
#ifndef INLINE_LOOPS
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), REDUCE_SAME_OPS); DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS);
#endif
} }
BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template class ReductionSameLoops, , LIBND4J_TYPES);

View File

@ -24,6 +24,7 @@
#include <execution/LaunchContext.h> #include <execution/LaunchContext.h>
#include <specials.h> #include <specials.h>
#include <logger.h> #include <logger.h>
#include <ops/ops.h>
// #include <cuda_runtime.h> // #include <cuda_runtime.h>
// #include <cuda.h> // #include <cuda.h>

View File

@ -34,16 +34,16 @@ namespace nd4j {
auto numHeads = projectionMatrix->sizeAt(0); auto numHeads = projectionMatrix->sizeAt(0);
auto projectedSize = projectionMatrix->sizeAt(1); auto projectedSize = projectionMatrix->sizeAt(1);
auto inputPerm = input->permute({1, 0, 2}); auto inputPerm = input->permute({1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps]
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps]
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn]
NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps]
nd4j::ops::matmul mmul; nd4j::ops::matmul mmul;
mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {}); mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
projected.permutei({2, 0, 1, 3}); projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength]
return projected; return projected;
} }

View File

@ -74,7 +74,7 @@ namespace nd4j {
template <> template <>
bool BlasHelper::hasGEMV<float>() { bool BlasHelper::hasGEMV<float>() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasSgemv; return _hasSgemv;
@ -83,7 +83,7 @@ namespace nd4j {
template <> template <>
bool BlasHelper::hasGEMV<double>() { bool BlasHelper::hasGEMV<double>() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasDgemv; return _hasDgemv;
@ -132,14 +132,14 @@ namespace nd4j {
bool BlasHelper::hasGEMV(const nd4j::DataType dtype) { bool BlasHelper::hasGEMV(const nd4j::DataType dtype) {
if(dtype == DataType::FLOAT32) { if(dtype == DataType::FLOAT32) {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasSgemv; return _hasSgemv;
#endif #endif
} }
if(dtype == DataType::DOUBLE) { if(dtype == DataType::DOUBLE) {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasDgemv; return _hasDgemv;
@ -150,7 +150,7 @@ namespace nd4j {
template <> template <>
bool BlasHelper::hasGEMM<float>() { bool BlasHelper::hasGEMM<float>() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasSgemm; return _hasSgemm;
@ -159,7 +159,7 @@ namespace nd4j {
template <> template <>
bool BlasHelper::hasGEMM<double>() { bool BlasHelper::hasGEMM<double>() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasDgemm; return _hasDgemm;
@ -208,14 +208,14 @@ namespace nd4j {
bool BlasHelper:: hasGEMM(const nd4j::DataType dtype) { bool BlasHelper:: hasGEMM(const nd4j::DataType dtype) {
if(dtype == DataType::FLOAT32) { if(dtype == DataType::FLOAT32) {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasSgemm; return _hasSgemm;
#endif #endif
} }
if(dtype == DataType::DOUBLE) { if(dtype == DataType::DOUBLE) {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
return _hasDgemm; return _hasDgemm;
@ -276,14 +276,14 @@ namespace nd4j {
} }
CblasSgemv BlasHelper::sgemv() { CblasSgemv BlasHelper::sgemv() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__)|| defined(HAVE_OPENBLAS)
return (CblasSgemv)&cblas_sgemv; return (CblasSgemv)&cblas_sgemv;
#else #else
return this->cblasSgemv; return this->cblasSgemv;
#endif #endif
} }
CblasDgemv BlasHelper::dgemv() { CblasDgemv BlasHelper::dgemv() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return (CblasDgemv)&cblas_dgemv; return (CblasDgemv)&cblas_dgemv;
#else #else
return this->cblasDgemv; return this->cblasDgemv;
@ -291,7 +291,7 @@ namespace nd4j {
} }
CblasSgemm BlasHelper::sgemm() { CblasSgemm BlasHelper::sgemm() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return (CblasSgemm)&cblas_sgemm; return (CblasSgemm)&cblas_sgemm;
#else #else
return this->cblasSgemm; return this->cblasSgemm;
@ -299,7 +299,7 @@ namespace nd4j {
} }
CblasDgemm BlasHelper::dgemm() { CblasDgemm BlasHelper::dgemm() {
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return (CblasDgemm)&cblas_dgemm; return (CblasDgemm)&cblas_dgemm;
#else #else
return this->cblasDgemm; return this->cblasDgemm;


@ -23,6 +23,7 @@
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <ops/declarable/headers/parity_ops.h> #include <ops/declarable/headers/parity_ops.h>
#include <helpers/DebugInfo.h> #include <helpers/DebugInfo.h>
#include <execution/Threads.h>
namespace nd4j { namespace nd4j {
DebugInfo DebugHelper::debugStatistics(NDArray const* input) { DebugInfo DebugHelper::debugStatistics(NDArray const* input) {
@ -88,11 +89,18 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_m
} }
*info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount}; *info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount};
_stdDevValue = 0; //math::nd4j_sqrt<double, double>(info->_stdDevValue / (input->lengthOf() - 1)); _stdDevValue = 0; //math::nd4j_sqrt<double, double>(info->_stdDevValue / (input->lengthOf() - 1));
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule (static) reduction(+:_stdDevValue))
for (Nd4jLong e = 0; e < input->lengthOf(); e++) { auto func = PRAGMA_REDUCE_DOUBLE {
double current = input->e<double>(e); auto _stdDevValue = 0.0;
_stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue; for (auto e = start; e < stop; e++) {
} double current = input->e<double>(e);
_stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue;
}
return _stdDevValue;
};
_stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf()); info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());
} }
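The hunk above replaces the OpenMP reduction with samediff::Threads::parallel_double, fed by a per-chunk lambda and an (_old, _new) combiner. Below is a minimal sketch of the same pattern for a different statistic; it assumes PRAGMA_REDUCE_DOUBLE opens a lambda over a [start, stop) chunk returning double and LAMBDA_AD opens the combiner, which is what their use above suggests, and it would sit in the same translation unit so that file's includes are already in place.
// Hedged sketch, not part of the commit: sum of squares via the same reduction machinery.
static double sumOfSquares(NDArray const* input) {
    auto func = PRAGMA_REDUCE_DOUBLE {
        double partial = 0.0;                          // per-thread accumulator
        for (auto e = start; e < stop; e++) {
            const double v = input->e<double>(e);
            partial += v * v;
        }
        return partial;                                // handed to the combiner
    };
    // the combiner merges the per-thread partial sums
    return samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
}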


@ -33,13 +33,11 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
switch(loss) { switch(loss) {
case MEAN: case MEAN:
PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
for(int i = 0; i < numInGradArrs; ++i) for(int i = 0; i < numInGradArrs; ++i)
*gradArrs[i] = 1. / gradArrs[i]->lengthOf(); *gradArrs[i] = 1. / gradArrs[i]->lengthOf();
break; break;
case SUM: case SUM:
PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
for(int i = 0; i < numInGradArrs; ++i) for(int i = 0; i < numInGradArrs; ++i)
*gradArrs[i] = 1.; *gradArrs[i] = 1.;
break; break;


@ -45,7 +45,7 @@ OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) {
else else
desiredNumThreads = nd4j::math::nd4j_min<int>(omp_get_max_threads(), desiredNumThreads); desiredNumThreads = nd4j::math::nd4j_min<int>(omp_get_max_threads(), desiredNumThreads);
#else #else
desiredNumThreads = 1; desiredNumThreads = nd4j::Environment::getInstance()->maxThreads();
#endif #endif
_numThreads = nd4j::math::nd4j_min<int>(N / maxItersPerThread, desiredNumThreads); _numThreads = nd4j::math::nd4j_min<int>(N / maxItersPerThread, desiredNumThreads);
} }
@ -75,7 +75,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
#ifdef _OPENMP #ifdef _OPENMP
return betterThreads(N, omp_get_max_threads()); return betterThreads(N, omp_get_max_threads());
#else #else
return 1; return betterThreads(N, nd4j::Environment::getInstance()->maxThreads());
#endif #endif
} }
@ -92,7 +92,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
#ifdef _OPENMP #ifdef _OPENMP
auto maxThreads = omp_get_max_threads(); auto maxThreads = omp_get_max_threads();
#else #else
auto maxThreads = 1; auto maxThreads = nd4j::Environment::getInstance()->maxThreads();
#endif #endif
// if there's only 1 thread allowed - nothing to do here // if there's only 1 thread allowed - nothing to do here
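These hunks change the non-OpenMP fallback from a hard-coded single thread to the environment's thread budget. A hedged sketch of the resulting pattern as a standalone helper, intended for the same translation unit so omp.h and the Environment header are already available there:
// Illustrative only; mirrors the #ifdef fallback used in the hunks above.
static int availableThreads() {
#ifdef _OPENMP
    return omp_get_max_threads();                           // honour OpenMP's limit
#else
    return nd4j::Environment::getInstance()->maxThreads();  // ThreadPool-backed builds
#endif
}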


@ -1,66 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_AGGREGATES_H
#define LIBND4J_AGGREGATES_H
#include <ops/aggregate_ops.h>
#include <helpers/DebugHelper.h>
#include <helpers/helper_ptrmap.h>
namespace functions {
namespace aggregate {
template<typename X>
class AggregatedFunction {
public:
#ifdef __CUDACC__
template<typename OpClass>
__device__ static void execCuda(X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments);
__device__ static void execCuda(int opNum, X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments);
__device__ static void aggregateBatch(int numAggregates, int opNum, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments);
__host__ static void aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments);
__host__ static void aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, void **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, void *realArguments, int numRealArguments);
#endif
template<typename OpClass>
inline static void exec(X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
OpClass::executeAggregate(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
}
inline static void exec(int opNum, X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), AGGREGATE_OPS);
}
};
}
}
#ifdef __CUDACC__
#endif
#endif //LIBND4J_AGGREGATES_H


@ -91,7 +91,7 @@ namespace functions {
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
#endif #else
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, void *x,
@ -105,7 +105,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(int opNum, static void exec(int opNum,
void *x, void *x,
@ -119,7 +121,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -144,7 +148,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(void *x,
@ -158,7 +164,10 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
#endif
}; };
} }
} }


@ -89,7 +89,7 @@ namespace functions {
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
#endif #else
static void exec(int opNum, static void exec(int opNum,
void *x, void *x,
@ -103,7 +103,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, void *x,
@ -117,7 +119,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -142,7 +146,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(void *x,
@ -156,7 +162,10 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
#endif
}; };
} }
} }


@ -89,7 +89,7 @@ namespace functions {
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
#endif #else
static void exec(int opNum, static void exec(int opNum,
void *x, void *x,
@ -103,7 +103,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void execInverse(int opNum, static void execInverse(int opNum,
void *x, void *x,
@ -117,7 +119,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
/** /**
* CPU execution * CPU execution
@ -142,7 +146,9 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
template<typename OpType> template<typename OpType>
static void execInverse(void *x, static void execInverse(void *x,
@ -156,7 +162,10 @@ namespace functions {
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset, Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ); Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
#endif
}; };
} }
} }
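All three broadcast headers now take an explicit [start, stop) range of TAD indices in exec() and execInverse(), so the caller decides how the work is sliced. The headers do not show that caller; the sketch below illustrates the intended slicing, where samediff::Threads::parallel_tad and the shape of its callable are assumptions taken from execution/Threads.h usage rather than from these hunks.
// Hedged sketch (illustrative, not from the commit): split a TAD range across workers.
#include <execution/Threads.h>
#include <vector>
#include <cstdint>

static void plusOnePerTad(std::vector<float>& buffer, int64_t numTads, int64_t tadLength) {
    auto func = [&](uint64_t /*thread_id*/, int64_t start, int64_t stop, int64_t increment) {
        for (auto i = start; i < stop; i += increment) {   // this worker's TADs
            float* tad = buffer.data() + i * tadLength;
            for (int64_t f = 0; f < tadLength; f++)
                tad[f] += 1.0f;                            // stand-in per-element op
        }
    };
    samediff::Threads::parallel_tad(func, 0, numTads);     // covers [0, numTads)
}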


@ -24,6 +24,7 @@
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -43,7 +44,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -55,7 +58,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_OPS); zTadOffset, start, stop), BROADCAST_OPS);
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
@ -71,7 +74,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -83,7 +88,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_OPS); zTadOffset, start, stop), BROADCAST_OPS);
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
@ -99,7 +104,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<Y *>(vy);
@ -131,10 +138,6 @@ namespace functions {
auto lenZ = shape::length(zTadShapeInfo); auto lenZ = shape::length(zTadShapeInfo);
auto lenY = shape::length(yShapeInfo); auto lenY = shape::length(yShapeInfo);
int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zTadShapeInfo); auto zEws = shape::elementWiseStride(zTadShapeInfo);
@ -142,19 +145,17 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) { auto oX = x + tadOffsets[i];
auto oX = x + tadOffsets[i]; auto oZ = z + zTadOffset[i];
auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], y[f]); oZ[f] = OpType::op(oX[f], y[f]);
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO){ else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO){
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -164,13 +165,10 @@ namespace functions {
} }
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -182,70 +180,61 @@ namespace functions {
} }
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
for (int i = 0; i < tads; i++) {
for (auto i = start; i < stop; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[offset], y[offset]); oZ[zOffset] = OpType::op(oX[offset], y[offset]);
} }
} }
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[offset], y[yOffset]); oZ[offset] = OpType::op(oX[offset], y[yOffset]);
} }
} }
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[xOffset], y[offset]); oZ[offset] = OpType::op(oX[xOffset], y[offset]);
} }
} }
} }
else { else {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
@ -253,17 +242,15 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
} }
} }
@ -285,7 +272,9 @@ namespace functions {
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset, Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<Y *>(vy);
@ -319,7 +308,7 @@ namespace functions {
int tadsPerThread = tads / TAD_THRESHOLD; int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread); int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -328,8 +317,7 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
if(kindOfLoop == nd4j::LoopKind::EWS1) { if(kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (unsigned int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -339,24 +327,20 @@ namespace functions {
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oY = x + tadOffsets[i]; auto oY = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -365,73 +349,63 @@ namespace functions {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
oZ[offset] = OpType::op(x[offset], oY[offset]); oZ[offset] = OpType::op(x[offset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[offset], oY[offset]); oZ[zOffset] = OpType::op(x[offset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[xOffset], oY[offset]); oZ[offset] = OpType::op(x[xOffset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[offset], oY[yOffset]); oZ[offset] = OpType::op(x[offset], oY[yOffset]);
} }
} };
} }
else { else {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
@ -439,20 +413,18 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
} }
} };
} }
} }
} }
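In the implementation above, the per-TAD loops now iterate the supplied [start, stop) slice directly, while execInverse keeps the small heuristic that sizes the worker count from the TAD count. A hedged sketch of that heuristic in plain C++ follows; the TAD_THRESHOLD value and the maxThreads figure are stand-ins, since the hunk only changes the clamp from omp_get_max_threads() to Environment::maxThreads().
// Illustrative helper, not from the commit.
#include <algorithm>
#include <cstdint>

static int pickThreads(int64_t tads, int64_t tadThreshold, int maxThreads) {
    int threads = static_cast<int>(std::max<int64_t>(1, tads / tadThreshold));
    return std::min(threads, maxThreads);   // never exceed the environment's budget
}

// e.g. pickThreads(1024, 32, 8) == 8, pickThreads(16, 32, 8) == 1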


@ -24,6 +24,7 @@
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -43,7 +44,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -55,7 +58,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_BOOL_OPS); zTadOffset, start, stop), BROADCAST_BOOL_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
@ -71,7 +74,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -83,7 +88,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_BOOL_OPS); zTadOffset, start, stop), BROADCAST_BOOL_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -99,7 +104,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -133,7 +140,7 @@ namespace functions {
int tadsPerThread = tads / TAD_THRESHOLD; int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread); int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
@ -142,10 +149,9 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
@ -153,101 +159,86 @@ namespace functions {
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
// TODO: cover this codebranch with tests
// all this stuff already happens within thread
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
oZ[offset] = OpType::op(oX[offset], y[offset]); oZ[offset] = OpType::op(oX[offset], y[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[offset], y[offset]); oZ[zOffset] = OpType::op(oX[offset], y[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[offset], y[yOffset]); oZ[offset] = OpType::op(oX[offset], y[yOffset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[xOffset], y[offset]); oZ[offset] = OpType::op(oX[xOffset], y[offset]);
} }
} };
} }
else { else {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
@ -255,20 +246,18 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
} }
} };
} }
} }
@ -286,7 +275,9 @@ namespace functions {
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset, Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -320,7 +311,7 @@ namespace functions {
int tadsPerThread = tads / TAD_THRESHOLD; int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread); int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -329,8 +320,7 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -340,8 +330,7 @@ namespace functions {
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
@ -355,14 +344,10 @@ namespace functions {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
// TODO: cover this codebranch with tests
// all this stuff already happens within thread
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
@ -377,15 +362,13 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[offset], oY[offset]); oZ[zOffset] = OpType::op(x[offset], oY[offset]);
} }
@ -398,15 +381,13 @@ namespace functions {
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[xOffset], oY[offset]); oZ[offset] = OpType::op(x[xOffset], oY[offset]);
} }
@ -419,16 +400,14 @@ namespace functions {
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[offset], oY[yOffset]); oZ[offset] = OpType::op(x[offset], oY[yOffset]);
} }
} }
@ -442,9 +421,7 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];


@ -24,6 +24,7 @@
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -43,7 +44,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -55,7 +58,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_INT_OPS); zTadOffset, start, stop), BROADCAST_INT_OPS);
} }
template <typename X> template <typename X>
@ -71,7 +74,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
@ -83,7 +88,7 @@ namespace functions {
xTadShapeInfo, xTadShapeInfo,
xTadOffset, xTadOffset,
zTadShapeInfo, zTadShapeInfo,
zTadOffset), BROADCAST_INT_OPS); zTadOffset, start, stop), BROADCAST_INT_OPS);
} }
template <typename X> template <typename X>
@ -99,7 +104,9 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffset, Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -133,7 +140,7 @@ namespace functions {
int tadsPerThread = tads / TAD_THRESHOLD; int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread); int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
@ -142,112 +149,95 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], y[f]); oZ[f] = OpType::op(oX[f], y[f]);
} };
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
// TODO: cover this codebranch with tests
// all this stuff already happens within thread
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
oZ[offset] = OpType::op(oX[offset], y[offset]); oZ[offset] = OpType::op(oX[offset], y[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[offset], y[offset]); oZ[zOffset] = OpType::op(oX[offset], y[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[offset], y[yOffset]); oZ[offset] = OpType::op(oX[offset], y[yOffset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[xOffset], y[offset]); oZ[offset] = OpType::op(oX[xOffset], y[offset]);
} }
} };
} }
else { else {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
@ -255,20 +245,18 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oX = x + tadOffsets[i]; auto oX = x + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
} }
} };
} }
} }
@ -286,7 +274,9 @@ namespace functions {
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadShapeInfo,
Nd4jLong *yTadOffset, Nd4jLong *yTadOffset,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset) { Nd4jLong *zTadOffset,
uint64_t start,
uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -320,7 +310,7 @@ namespace functions {
int tadsPerThread = tads / TAD_THRESHOLD; int tadsPerThread = tads / TAD_THRESHOLD;
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread); int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads()); threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
@ -329,46 +319,39 @@ namespace functions {
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(x[f], oY[f]); oZ[f] = OpType::op(x[f], oY[f]);
} };
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (uint f = 0; f < tadLength; f++) for (uint f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
// TODO: cover this codebranch with tests
// all this stuff already happens within thread
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
oZ[offset] = OpType::op(x[offset], oY[offset]); oZ[offset] = OpType::op(x[offset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
@ -377,64 +360,54 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[offset], oY[offset]); oZ[zOffset] = OpType::op(x[offset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[xOffset], oY[offset]); oZ[offset] = OpType::op(x[xOffset], oY[offset]);
} }
} };
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[offset], oY[yOffset]); oZ[offset] = OpType::op(x[offset], oY[yOffset]);
} }
} };
} }
else { else {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
uint tadShapeInfoZCast[MAX_RANK]; uint tadShapeInfoZCast[MAX_RANK];
@ -442,9 +415,7 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) for (auto i = start; i < stop; i ++) {
for (int i = 0; i < tads; i++) {
auto oZ = z + zTadOffset[i]; auto oZ = z + zTadOffset[i];
auto oY = y + tadOffsets[i]; auto oY = y + tadOffsets[i];
@ -455,7 +426,7 @@ namespace functions {
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
} }
} };
} }
} }
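
Note on the broadcast change above: the exec overloads no longer open their own OpenMP parallel regions; they now accept a start/stop pair and walk only that slice of TADs, while the caller decides how the TAD range is split across workers. The helper below is a hypothetical, self-contained sketch of that calling pattern using plain std::thread (the library itself routes this through samediff::Threads, whose parallel_for is visible later in this diff); names and the even-split arithmetic are illustrative, not the library code.

    #include <cstdint>
    #include <functional>
    #include <thread>
    #include <vector>

    // Split [start, stop) into numThreads chunks and run
    // body(thread_id, chunkStart, chunkStop) on each chunk.
    // Returns the number of threads actually used, like parallel_for.
    static int parallel_chunks(const std::function<void(int, uint64_t, uint64_t)> &body,
                               uint64_t start, uint64_t stop, int numThreads) {
        const uint64_t len = stop - start;
        if (len == 0 || numThreads < 1)
            return 0;
        if ((uint64_t) numThreads > len)
            numThreads = (int) len;

        const uint64_t span = len / numThreads;
        std::vector<std::thread> pool;
        for (int t = 0; t < numThreads; t++) {
            const uint64_t s = start + t * span;
            const uint64_t e = (t == numThreads - 1) ? stop : s + span;
            pool.emplace_back(body, t, s, e);   // worker t owns TADs [s, e)
        }
        for (auto &th : pool)
            th.join();
        return numThreads;
    }

    // Each worker then runs the serial per-TAD loop shown in the diff:
    // for (auto i = start; i < stop; i++) { ... process TAD i ... }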

View File

@ -23,6 +23,7 @@
#include <Loops.h> #include <Loops.h>
#include <types/types.h> #include <types/types.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include "../legacy_ops.h" #include "../legacy_ops.h"
using namespace simdOps; using namespace simdOps;
@ -44,8 +45,7 @@ void IndexReduce<X,Y>::exec(const int opNum,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -64,42 +64,41 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
IndexValue<X> intermediatery[64];
for (int e = 0; e < maxThreads; e++)
intermediatery[e].index = -1;
if (xEws == 1) { if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{ intermediatery[thread_id] = OpType::startingIndexValue(x);
auto local = OpType::startingIndexValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = info.getItersPerThread(threadNum); for (auto i = start; i < stop; i += increment) {
IndexValue<X> curr(x[i], i);
for (Nd4jLong i = 0; i < ulen; i++) { intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
IndexValue<X> curr(x[i + threadOffset], i + threadOffset);
local = OpType::update(local, curr, extraParams);
} }
};
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
PRAGMA_OMP_CRITICAL
startingIndex = OpType::update(startingIndex, local, extraParams);
}
} else { } else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{ intermediatery[thread_id] = OpType::startingIndexValue(x);
auto local = OpType::startingIndexValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = info.getItersPerThread(threadNum); for (auto i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
for (Nd4jLong i = 0; i < ulen; i++) { IndexValue<X> curr(x[offset], i);
auto offset = shape::indexOffset(threadOffset + i, xShapeInfo, xShapeInfoCast, canCastX); intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
IndexValue<X> curr(x[offset], threadOffset + i);
local = OpType::update(local, curr, extraParams);
} }
};
PRAGMA_OMP_CRITICAL maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
startingIndex = OpType::update(startingIndex, local, extraParams);
} for (int e = 0; e < maxThreads; e++)
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
} }
return startingIndex.index; return startingIndex.index;
} }
@ -124,9 +123,10 @@ void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto indexValue = OpType::startingIndexValue(x); const auto indexValue = OpType::startingIndexValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLen; i++) for (uint i = 0; i < zLen; i++)
z[i] = (Z) indexValue.index;; z[i] = (Z) indexValue.index;
return; return;
} }
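
The scalar index-reduce above switches from an OpenMP critical section to per-thread partials: each worker accumulates into intermediatery[thread_id], parallel_for reports how many threads actually ran, and the partials are folded serially afterwards. A standalone sketch of that shape (an argmax stands in for OpType::update; assumes len > 0 and numThreads >= 1):

    #include <cstdint>
    #include <thread>
    #include <vector>

    struct IndexValue { double value; int64_t index; };

    int64_t argmax_parallel(const double *x, int64_t len, int numThreads) {
        if (numThreads > len) numThreads = (int) len;
        std::vector<IndexValue> partial(numThreads, {x[0], 0});   // one private partial per thread

        std::vector<std::thread> pool;
        const int64_t span = len / numThreads;
        for (int t = 0; t < numThreads; t++) {
            const int64_t start = t * span;
            const int64_t stop  = (t == numThreads - 1) ? len : start + span;
            pool.emplace_back([&, t, start, stop] {
                for (int64_t i = start; i < stop; i++)            // no locking inside the hot loop
                    if (x[i] > partial[t].value) partial[t] = {x[i], i};
            });
        }
        for (auto &th : pool) th.join();

        IndexValue best = partial[0];                             // serial merge replaces OMP_CRITICAL
        for (int t = 1; t < numThreads; t++)
            if (partial[t].value > best.value) best = partial[t];
        return best.index;
    }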

View File

@ -26,6 +26,7 @@
#include <helpers/shape.h> #include <helpers/shape.h>
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <OmpLaunchHelper.h> #include <OmpLaunchHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -42,7 +43,9 @@ namespace functions {
void *z, void *z,
Nd4jLong zEws, Nd4jLong zEws,
void *extraParams, void *extraParams,
Nd4jLong n) { Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -50,7 +53,7 @@ namespace functions {
z, z,
zEws, zEws,
extraParams, extraParams,
n), PAIRWISE_TRANSFORM_OPS); n, start, stop), PAIRWISE_TRANSFORM_OPS);
}; };
@ -61,48 +64,24 @@ namespace functions {
void *vy, Nd4jLong yEws, void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws, void *vz, Nd4jLong zEws,
void *vextraParams, void *vextraParams,
const Nd4jLong n) { const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<Y *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
nd4j::OmpLaunchHelper info(n);
if (xEws == 1 && yEws == 1 && zEws == 1) { if (xEws == 1 && yEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], y[i], extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto yi = y + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], yi[i], extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto yi = y + yEws*threadOffset;
auto zi = z + zEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
}
} }
} }
@ -115,14 +94,16 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams) { void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
yShapeInfo, yShapeInfo,
z, z,
zShapeInfo, zShapeInfo,
extraParams), extraParams, start, stop),
PAIRWISE_TRANSFORM_OPS); PAIRWISE_TRANSFORM_OPS);
}; };
@ -136,7 +117,9 @@ namespace functions {
Nd4jLong* yShapeInfo, Nd4jLong* yShapeInfo,
void *vz, void *vz,
Nd4jLong* zShapeInfo, Nd4jLong* zShapeInfo,
void *vextraParams) { void *vextraParams,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy); auto y = reinterpret_cast<Y *>(vy);
@ -148,7 +131,6 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo); auto zEws = shape::elementWiseStride(zShapeInfo);
nd4j::OmpLaunchHelper info(n);
if (shape::isScalar(yShapeInfo)) { if (shape::isScalar(yShapeInfo)) {
@ -156,38 +138,22 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for(auto i = start; i < stop; i++) {
{ auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadNum = omp_get_thread_num(); z[offset] = OpType::op(x[offset], y[0], extraParams);
auto threadOffset = info.getThreadOffset(threadNum); };
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for(unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[0], extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for(auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
};
PRAGMA_OMP_SIMD
for(unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
}
}
} }
return; return;
} }
@ -198,96 +164,63 @@ namespace functions {
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
} }
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
} }
else { else {
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); z[offset] = OpType::op(x[offset], y[offset], extraParams);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[offset], extraParams);
}
} }
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
}
}
} }
else { else {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
@ -295,20 +228,13 @@ namespace functions {
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
PRAGMA_OMP_SIMD };
for (unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
} }
} }
} }
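
For the pairwise transform above, the OmpLaunchHelper offset bookkeeping disappears entirely: the kernel receives an already-sliced [start, stop) range and just runs the strided loop, keeping PRAGMA_OMP_SIMD in front of it for vectorization only. What a single worker now executes, reduced to a self-contained template (names illustrative):

    #include <cstdint>

    template <typename X, typename Y, typename Z, typename Op>
    void pairwise_chunk(const X *x, int64_t xEws,
                        const Y *y, int64_t yEws,
                        Z *z, int64_t zEws,
                        Z *extraParams, Op op,
                        uint64_t start, uint64_t stop) {
        if (xEws == 1 && yEws == 1 && zEws == 1) {
            // contiguous fast path, mirrors the EWS1 branch in the diff
            for (auto i = start; i < stop; i++)
                z[i] = op(x[i], y[i], extraParams);
        } else {
            // generic element-wise-stride path
            for (auto i = start; i < stop; i++)
                z[i * zEws] = op(x[i * xEws], y[i * yEws], extraParams);
        }
    }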

View File

@ -1,106 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by remote on 2018-09-20.
//
#include <ops/ops.h>
#include <loops/pairwise_transform.h>
#include <types/types.h>
#include <templatemath.h>
#include <helpers/shape.h>
#include <op_boilerplate.h>
#include <OmpLaunchHelper.h>
using namespace simdOps;
namespace functions {
namespace pairwise_transforms {
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong xEws,
void *y,
Nd4jLong yEws,
void *z,
Nd4jLong zEws,
void *extraParams,
Nd4jLong n) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
xEws,
y,
yEws,
z,
zEws,
extraParams,
n), PAIRWISE_TRANSFORM_OPS);
};
template <typename X, typename Y, typename Z>
template <typename OpType>
void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong xEws,
void *vy, Nd4jLong yEws,
void *vz, Nd4jLong zEws,
void *vextraParams,
const Nd4jLong n) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<Y *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
nd4j::OmpLaunchHelper info(n);
if (xEws == 1 && yEws == 1 && zEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto yi = y + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], yi[i], extraParams);
}
}
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto yi = y + yEws*threadOffset;
auto zi = z + zEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
}
}
}
}
}
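
The deleted file above is the old standalone pairwise implementation, where every OpenMP thread derived its own offset and iteration count from OmpLaunchHelper inside the kernel. With chunking now done by the caller, that bookkeeping is redundant. Roughly, the removed per-thread arithmetic amounted to an even split of n elements (the exact OmpLaunchHelper math is not reproduced here; this is the generic partition it approximates):

    #include <cstdint>

    struct LaunchInfo {
        int64_t n;
        int numThreads;
        int64_t threadOffset(int t) const { return t * (n / numThreads); }
        int64_t itersPerThread(int t) const {
            // last thread picks up the remainder
            return (t == numThreads - 1) ? n - threadOffset(t) : n / numThreads;
        }
    };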

View File

@ -22,6 +22,7 @@
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <OmpLaunchHelper.h> #include <OmpLaunchHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -38,7 +39,9 @@ namespace functions {
void *z, void *z,
Nd4jLong zEws, Nd4jLong zEws,
void *extraParams, void *extraParams,
Nd4jLong n) { Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -46,7 +49,7 @@ namespace functions {
z, z,
zEws, zEws,
extraParams, extraParams,
n), PAIRWISE_BOOL_OPS); n, start, stop), PAIRWISE_BOOL_OPS);
}; };
@ -60,46 +63,24 @@ namespace functions {
void *vz, void *vz,
Nd4jLong zEws, Nd4jLong zEws,
void *vextraParams, void *vextraParams,
const Nd4jLong n) { const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
nd4j::OmpLaunchHelper info(n);
if (xEws == 1 && yEws == 1 && zEws == 1) { if (xEws == 1 && yEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], y[i], extraParams);
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto yi = y + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], yi[i], extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto yi = y + yEws*threadOffset;
auto zi = z + zEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
}
} }
} }
@ -112,14 +93,16 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams) { void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
yShapeInfo, yShapeInfo,
z, z,
zShapeInfo, zShapeInfo,
extraParams), extraParams, start, stop),
PAIRWISE_BOOL_OPS); PAIRWISE_BOOL_OPS);
}; };
@ -129,7 +112,9 @@ namespace functions {
void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo, void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo, void *vy, Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo, void *vz, Nd4jLong* zShapeInfo,
void *vextraParams) { void *vextraParams,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -141,8 +126,6 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo); auto zEws = shape::elementWiseStride(zShapeInfo);
nd4j::OmpLaunchHelper info(n);
if (shape::isScalar(yShapeInfo)) { if (shape::isScalar(yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
@ -150,37 +133,22 @@ namespace functions {
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for(auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); z[offset] = OpType::op(x[offset], y[0], extraParams);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); };
PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[0], extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for(auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
};
PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
}
}
} }
return; return;
} }
@ -189,96 +157,62 @@ namespace functions {
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
} }
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
} }
else { else {
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); z[offset] = OpType::op(x[offset], y[offset], extraParams);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); };
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[offset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
}
}
} }
else { else {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
@ -286,20 +220,13 @@ namespace functions {
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
PRAGMA_OMP_SIMD };
for (Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
} }
} }
} }
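
The bool variant above (and the int variant that follows) keeps the same branch structure: when two buffers share shape and strides, one indexOffset result serves both sides; otherwise each buffer gets its own offset. A minimal stand-in for why those branches exist, with indexOffset and the cast tables replaced by a plain 2-D stride walk (illustrative only):

    #include <cstdint>

    inline uint64_t offset2d(uint64_t i, uint64_t cols, int64_t rowStride, int64_t colStride) {
        return (i / cols) * rowStride + (i % cols) * colStride;
    }

    template <typename X>
    void copy_chunk(const X *x, X *z, uint64_t cols,
                    int64_t xRowStride, int64_t xColStride,
                    int64_t zRowStride, int64_t zColStride,
                    uint64_t start, uint64_t stop) {
        const bool sameLayout = (xRowStride == zRowStride) && (xColStride == zColStride);
        if (sameLayout) {
            // one offset computed per element, reused for both buffers
            for (auto i = start; i < stop; i++) {
                auto off = offset2d(i, cols, xRowStride, xColStride);
                z[off] = x[off];
            }
        } else {
            // separate offsets per buffer, as in the fallback branch
            for (auto i = start; i < stop; i++)
                z[offset2d(i, cols, zRowStride, zColStride)] =
                    x[offset2d(i, cols, xRowStride, xColStride)];
        }
    }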

View File

@ -22,6 +22,7 @@
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <OmpLaunchHelper.h> #include <OmpLaunchHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -38,7 +39,9 @@ namespace functions {
void *z, void *z,
Nd4jLong zEws, Nd4jLong zEws,
void *extraParams, void *extraParams,
Nd4jLong n) { Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xEws, xEws,
y, y,
@ -46,7 +49,7 @@ namespace functions {
z, z,
zEws, zEws,
extraParams, extraParams,
n), PAIRWISE_INT_OPS); n, start, stop), PAIRWISE_INT_OPS);
}; };
@ -60,46 +63,24 @@ namespace functions {
void *vz, void *vz,
Nd4jLong zEws, Nd4jLong zEws,
void *vextraParams, void *vextraParams,
const Nd4jLong n) { const Nd4jLong n,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
nd4j::OmpLaunchHelper info(n);
if (xEws == 1 && yEws == 1 && zEws == 1) { if (xEws == 1 && yEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], y[i], extraParams);
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto yi = y + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], yi[i], extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto yi = y + yEws*threadOffset;
auto zi = z + zEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++)
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
}
} }
} }
@ -112,14 +93,16 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams) { void *extraParams,
const uint64_t start,
const uint64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
y, y,
yShapeInfo, yShapeInfo,
z, z,
zShapeInfo, zShapeInfo,
extraParams), extraParams, start, stop),
PAIRWISE_INT_OPS); PAIRWISE_INT_OPS);
}; };
@ -129,7 +112,9 @@ namespace functions {
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo, void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
void *vy, Nd4jLong* yShapeInfo, void *vy, Nd4jLong* yShapeInfo,
void *vz, Nd4jLong* zShapeInfo, void *vz, Nd4jLong* zShapeInfo,
void *vextraParams) { void *vextraParams,
const uint64_t start,
const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -141,46 +126,28 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo); auto zEws = shape::elementWiseStride(zShapeInfo);
nd4j::OmpLaunchHelper info(n);
if (shape::isScalar(yShapeInfo)) { if (shape::isScalar(yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for(auto i = start; i < stop; i++) {
{ auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadNum = omp_get_thread_num(); z[offset] = OpType::op(x[offset], y[0], extraParams);
auto threadOffset = info.getThreadOffset(threadNum); };
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[0], extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for(auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
};
PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
}
}
} }
return; return;
} }
@ -189,96 +156,63 @@ namespace functions {
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
} }
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
} }
else { else {
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); z[offset] = OpType::op(x[offset], y[offset], extraParams);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); };
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[offset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
}
}
} }
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
};
PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
}
}
} }
else { else {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
@ -286,20 +220,13 @@ namespace functions {
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
PRAGMA_OMP_SIMD };
for (Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
} }
} }
} }
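The kernels in this file no longer open their own OpenMP parallel regions: each exec overload now receives an explicit [start, stop) slice and iterates only that slice, while the caller decides how the full length is split across threads. Below is a minimal standalone sketch of that division of labour, assuming nothing beyond the C++ standard library; parallel_for_sketch and scalarAddKernel are illustrative names, not libnd4j APIs, and the real samediff::Threads::parallel_for is more sophisticated than this.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for the range splitter: chop [0, n) into contiguous
// chunks and hand each chunk to a kernel that only ever sees its own slice.
static void parallel_for_sketch(const std::function<void(int64_t, int64_t)> &body,
                                int64_t n, int numThreads) {
    std::vector<std::thread> workers;
    const int64_t chunk = (n + numThreads - 1) / numThreads;
    for (int t = 0; t < numThreads; t++) {
        const int64_t start = t * chunk;
        const int64_t stop = std::min<int64_t>(n, start + chunk);
        if (start >= stop)
            break;
        workers.emplace_back(body, start, stop);
    }
    for (auto &w : workers)
        w.join();
}

// Range-based kernel in the style of the refactored exec(..., start, stop):
// it never asks how many threads exist, it just walks its slice.
static void scalarAddKernel(const float *x, float y, float *z, int64_t start, int64_t stop) {
    for (int64_t i = start; i < stop; i++)
        z[i] = x[i] + y;
}

A call site then looks like parallel_for_sketch([&](int64_t s, int64_t e) { scalarAddKernel(x, 1.0f, z, s, e); }, length, 4); keeping the kernel oblivious to the thread count is what lets one code path serve both the threaded and the single-threaded case.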

@ -52,28 +52,22 @@ namespace functions {
auto length = shape::length(zShapeInfo); auto length = shape::length(zShapeInfo);
// nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
nd4j::OmpLaunchHelper info(length);
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (auto i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
@ -82,19 +76,16 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@ -103,19 +94,16 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments); z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
@ -124,19 +112,16 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < info.getItersPerThread(threadNum); i++) { for (uint64_t i = start; i < stop; i += increment) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments); z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
else { else {
@ -147,20 +132,17 @@ namespace functions {
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments); z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
}; };
@ -184,41 +166,34 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
nd4j::OmpLaunchHelper info(length);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments); z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
} }
@ -232,25 +207,21 @@ namespace functions {
auto length = shape::length(zShapeInfo); auto length = shape::length(zShapeInfo);
//nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
nd4j::OmpLaunchHelper info(length); nd4j::OmpLaunchHelper info(length);
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) auto func = PRAGMA_THREADS_FOR {
{
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (uint64_t i = start; i < stop; i += increment) {
auto offset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ); auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments); z[offset] = OpClass::op(i, length, rng, extraArguments);
} }
} };
samediff::Threads::parallel_for(func, 0, length, 1);
} }
template<typename X> template<typename X>
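In the random kernels above, the old OMP parallel region with hand-computed per-thread offsets becomes a lambda built by PRAGMA_THREADS_FOR and dispatched via samediff::Threads::parallel_for(func, 0, length, 1). Written out by hand, the functor has roughly the shape below; the parameter list of the macro expansion and the trivial element-wise body are illustrative assumptions, not the library's exact code.

#include <cstdint>

// The functor sees thread_id / start / stop / increment and touches only its
// own sub-range, so the identical body runs unchanged on 1..N threads.
void halveInPlace(float *z, const float *x, int64_t length) {
    auto func = [&](uint64_t /*thread_id*/, int64_t start, int64_t stop, int64_t increment) {
        for (int64_t i = start; i < stop; i += increment)
            z[i] = x[i] * 0.5f;   // stand-in for OpClass::op(x[i], ..., rng, extraArguments)
    };

    // In the diff this functor is handed to samediff::Threads::parallel_for(func, 0, length, 1);
    // invoking it once over the whole range is the single-threaded equivalent.
    func(0, 0, length, 1);
}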

@ -55,7 +55,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++) for (uint i = 0; i < length; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -65,25 +65,14 @@ namespace functions {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
X intermediate[256];
for (int e = 0; e < maxThreads; e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads) for (auto i = 0; i < length; i++)
for(Nd4jLong i = 0; i < length; ++i) startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
z[0] = OpType::postProcess(startingValue, length, extraParams);
for (int e = 0; e < maxThreads; e++)
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
} }
} }
@ -102,23 +91,14 @@ namespace functions {
return execScalar<OpType>(x, xEws, length, extraParams); return execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
for (int e = 0; e < omp_get_max_threads(); e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_SIMD for (auto i = 0; i < length; i++)
for(Nd4jLong i = 0; i < length; ++i) startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
for (int e = 0; e < omp_get_max_threads(); e++) return OpType::postProcess(startingValue, length, extraParams);
start = OpType::update(start, intermediate[e], extraParams);
delete[] intermediate;
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
} }
} }
@ -150,8 +130,8 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), REDUCE_BOOL_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -164,7 +144,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
@ -176,7 +156,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++) for (uint i = 0; i < resultLength; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -205,9 +185,9 @@ namespace functions {
} }
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#else #else
nd4j::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
@ -227,49 +207,33 @@ namespace functions {
template <typename X, typename Z> template <typename X, typename Z>
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
Z intermediate[64];
auto startingVal = OpType::startingValue(x); PRAGMA_OMP_SIMD
nd4j::OmpLaunchHelper info(length); for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
if (xEws == 1) { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
auto local = OpType::startingValue(x); } else {
auto threadNum = omp_get_thread_num(); for (auto i = start; i < stop; i++)
auto threadOffset = info.getThreadOffset(threadNum); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
}
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
} }
} };
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) // merge results
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
PRAGMA_OMP_CRITICAL // return result
startingVal = OpType::update(startingVal, local, extraParams); return OpType::postProcess(intermediate[0], length, extraParams);
}
}
return OpType::postProcess(startingVal, length, extraParams);
} }
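The scalar reductions now follow an accumulate-then-merge shape: every thread folds its slice into its own slot of a fixed intermediate[64] buffer, the parallel_for call reports how many threads actually ran, and the slots are merged serially before postProcess. Here is a standalone sketch of that shape with plain summation standing in for OpType::op/update/postProcess; the 64-slot cap mirrors the diff, everything else (names, the sum itself) is illustrative.

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// Each worker owns intermediate[t]; nothing is shared while the threads run,
// and slot 0 absorbs the other partials only after join().
double reduceSumSketch(const double *x, int64_t length, int numThreads) {
    numThreads = std::max(1, std::min(numThreads, 64));   // fixed 64-slot buffer, as in the diff
    double intermediate[64];
    for (int e = 0; e < numThreads; e++)
        intermediate[e] = 0.0;                            // OpType::startingValue(x) in the diff

    std::vector<std::thread> workers;
    const int64_t chunk = (length + numThreads - 1) / numThreads;
    for (int t = 0; t < numThreads; t++) {
        const int64_t start = t * chunk;
        const int64_t stop = std::min<int64_t>(length, start + chunk);
        if (start >= stop)
            break;
        workers.emplace_back([&, t, start, stop]() {
            for (int64_t i = start; i < stop; i++)
                intermediate[t] += x[i];                  // OpType::update(..., OpType::op(x[i]), ...)
        });
    }
    for (auto &w : workers)
        w.join();

    for (int e = 1; e < numThreads; e++)                  // merge step from the diff
        intermediate[0] += intermediate[e];
    return intermediate[0];                               // OpType::postProcess(intermediate[0], length, ...)
}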

@ -59,9 +59,10 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++) for (uint i = 0; i < length; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
} }
@ -69,25 +70,29 @@ namespace functions {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
X intermediate[256];
for (int e = 0; e < maxThreads; e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads) PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < length; ++i) for (auto e = 0; e < maxThreads; e++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[e] = OpType::startingValue(x);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++)
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
for (int e = 0; e < maxThreads; e++) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams); // merge results
for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
// write out results
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
} }
} }
@ -105,23 +110,14 @@ namespace functions {
return execScalar<OpType>(x, xEws, length, extraParams); return execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
for (int e = 0; e < omp_get_max_threads(); e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_SIMD for (auto i = 0; i < length; i++)
for(Nd4jLong i = 0; i < length; ++i) startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
for (int e = 0; e < omp_get_max_threads(); e++) return OpType::postProcess(startingValue, length, extraParams);
start = OpType::update(start, intermediate[e], extraParams);
delete[] intermediate;
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
} }
} }
@ -153,7 +149,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
extraParams, extraParams,
@ -162,7 +158,7 @@ namespace functions {
dimension, dimension,
dimensionLength, dimensionLength,
tadShapeInfo, tadShapeInfo,
tadOffset), tadOffset, start, stop),
REDUCE_FLOAT_OPS); REDUCE_FLOAT_OPS);
} }
@ -176,7 +172,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
@ -188,7 +184,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x)); const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++) for (uint i = 0; i < resultLength; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -222,9 +218,9 @@ namespace functions {
} }
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#else #else
nd4j::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
@ -245,49 +241,34 @@ namespace functions {
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
Z intermediate[64];
auto startingVal = OpType::startingValue(x); PRAGMA_OMP_SIMD
nd4j::OmpLaunchHelper info(length); for (auto e = 0; e < maxThreads; e++)
int nt = info._numThreads; intermediate[e] = OpType::startingValue(x);
if (xEws == 1) { auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
auto local = OpType::startingValue(x); } else {
auto threadNum = omp_get_thread_num(); for (auto i = start; i < stop; i++)
auto threadOffset = info.getThreadOffset(threadNum); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++)
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
} }
} };
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) // merge results
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
PRAGMA_OMP_CRITICAL // return result
startingVal = OpType::update(startingVal, local, extraParams); return OpType::postProcess(intermediate[0], length, extraParams);
} }
}
return OpType::postProcess(startingVal, length, extraParams);
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES);

@ -55,7 +55,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++) for (uint i = 0; i < length; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -65,25 +65,29 @@ namespace functions {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
X intermediate[256];
for (int e = 0; e < maxThreads; e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
Z intermediate[64];
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads) PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < length; ++i) for (auto e = 0; e < maxThreads; e++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[e] = OpType::startingValue(x);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++)
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
for (int e = 0; e < maxThreads; e++) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams); // merge results
for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
// write out results
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
} }
} }
@ -103,23 +107,14 @@ namespace functions {
return execScalar<OpType>(x, xEws, length, extraParams); return execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
for (int e = 0; e < omp_get_max_threads(); e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_SIMD for (auto i = 0; i < length; i++)
for(Nd4jLong i = 0; i < length; ++i) startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
for (int e = 0; e < omp_get_max_threads(); e++) return OpType::postProcess(startingValue, length, extraParams);
start = OpType::update(start, intermediate[e], extraParams);
delete[] intermediate;
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
} }
} }
@ -152,8 +147,8 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), REDUCE_LONG_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -166,7 +161,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vresult); auto z = reinterpret_cast<Z *>(vresult);
@ -178,7 +173,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++) for (uint i = 0; i < resultLength; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -212,9 +207,9 @@ namespace functions {
} }
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#else #else
nd4j::ReductionLongLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionLongLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
@ -235,48 +230,34 @@ namespace functions {
template <typename OpType> template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
Z intermediate[64];
auto startingVal = OpType::startingValue(x); PRAGMA_OMP_SIMD
nd4j::OmpLaunchHelper info(length); for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++)
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
{ } else {
auto local = OpType::startingValue(x); for (auto i = start; i < stop; i++)
auto threadNum = omp_get_thread_num(); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++)
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
}
} }
else { };
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) // merge results
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
PRAGMA_OMP_CRITICAL // return result
startingVal = OpType::update(startingVal, local, extraParams); return OpType::postProcess(intermediate[0], length, extraParams);
} }
}
return OpType::postProcess(startingVal, length, extraParams);
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES);

@ -57,7 +57,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++) for (uint i = 0; i < length; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -67,25 +67,29 @@ namespace functions {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
else { else {
X start = OpType::startingValue(x); auto startingValue = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
X intermediate[256];
for (int e = 0; e < maxThreads; e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
X intermediate[64];
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads) PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < length; ++i) for (auto e = 0; e < maxThreads; e++)
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams); intermediate[e] = OpType::startingValue(x);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++)
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
};
for (int e = 0; e < maxThreads; e++) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, length, extraParams); // merge results
for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
// write out results
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
} }
} }
@ -103,26 +107,15 @@ namespace functions {
if (xEws >= 1) { if (xEws >= 1) {
return execScalar<OpType>(x, xEws, length, extraParams); return execScalar<OpType>(x, xEws, length, extraParams);
} } else {
else { auto startingValue = OpType::startingValue(x);
X start = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
X intermediate[256];
for (int e = 0; e < maxThreads; e++)
intermediate[e] = start;
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads) for (auto i = 0; i < length; i++)
for(Nd4jLong i = 0; i < length; ++i) startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
for (int e = 0; e < maxThreads; e++) return OpType::postProcess(startingValue, length, extraParams);
start = OpType::update(start, intermediate[e], extraParams);
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
} }
} }
@ -154,7 +147,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
xShapeInfo, xShapeInfo,
extraParams, extraParams,
@ -163,7 +156,7 @@ namespace functions {
dimension, dimension,
dimensionLength, dimensionLength,
tadShapeInfo, tadShapeInfo,
tadOffset), tadOffset, start, stop),
REDUCE_SAME_OPS); REDUCE_SAME_OPS);
} }
@ -177,7 +170,7 @@ namespace functions {
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset) { Nd4jLong *tadOffset, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
@ -189,7 +182,7 @@ namespace functions {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLength; i++) for (uint i = 0; i < zLength; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
@ -223,9 +216,9 @@ namespace functions {
} }
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#else #else
nd4j::ReductionSameLoops<X>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); nd4j::ReductionSameLoops<X>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
#endif #endif
} }
@ -246,48 +239,34 @@ namespace functions {
template <typename OpType> template <typename OpType>
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
X intermediate[64];
auto startingVal = OpType::startingValue(x); PRAGMA_OMP_SIMD
nd4j::OmpLaunchHelper info(length); for (auto e = 0; e < maxThreads; e++)
intermediate[e] = OpType::startingValue(x);
auto func = PRAGMA_THREADS_FOR {
if (xEws == 1) { if (xEws == 1) {
for (auto i = start; i < stop; i++)
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
{ } else {
auto local = OpType::startingValue(x); for (auto i = start; i < stop; i++)
auto threadNum = omp_get_thread_num(); intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++)
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
}
} }
else { };
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) // merge results
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); for (int e = 1; e < maxThreads; e++)
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
PRAGMA_OMP_CRITICAL // return result
startingVal = OpType::update(startingVal, local, extraParams); return OpType::postProcess(intermediate[0], length, extraParams);
} }
}
return OpType::postProcess(startingVal, length, extraParams);
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES);
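A detail worth noting in these reduce kernels: the samediff::Threads::parallel_for overload that accepts a requested thread count also returns the number of threads it actually used, and in the new code that return value, not the request, bounds the serial merge loop over intermediate[]. The sketch below illustrates that contract with a trivial splitter; the behaviour of the real library call (when and why it uses fewer threads) is an assumption here.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Illustrative splitter mimicking the contract used in the diff: it may run
// fewer threads than requested (for example on tiny ranges) and reports how many ran.
static int parallel_for_counted(const std::function<void(int, int64_t, int64_t)> &body,
                                int64_t n, int requestedThreads) {
    const int used = static_cast<int>(std::max<int64_t>(1, std::min<int64_t>(requestedThreads, n)));
    const int64_t chunk = (n + used - 1) / used;
    std::vector<std::thread> workers;
    for (int t = 0; t < used; t++)
        workers.emplace_back(body, t, t * chunk, std::min<int64_t>(n, (t + 1) * chunk));
    for (auto &w : workers)
        w.join();
    return used;   // the caller merges exactly this many partials
}

A caller in the style of the diff then does maxThreads = parallel_for_counted(func, length, maxThreads); and folds intermediate[1 .. maxThreads-1] into intermediate[0], so only slots a worker actually touched take part in the merge.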

@ -24,6 +24,7 @@
#include <loops/legacy_ops.h> #include <loops/legacy_ops.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <Loops.h> #include <Loops.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -51,72 +52,82 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return; return;
const auto startingVal = OpType::startingValue(x); const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++) for (uint i = 0; i < length; i++)
z[i] = startingVal; z[i] = startingVal;
return; return;
} }
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f}; Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
// it's possible case for EqualsWithEps op
if (extraParams != nullptr)
extraParamsVals[2] = extraParams[0];
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
Z startingVal = OpType::startingValue(x); Z startingVal = OpType::startingValue(x);
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads()); int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
nd4j::OmpLaunchHelper t(length, maxThreads); Z intermediate[64];
Z intermediate[256]; Z extraParamsLocal[3 * 64];
Z extraParamsLocal[3 * 256];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int e = 0; e < maxThreads; e++) for (int e = 0; e < maxThreads; e++)
intermediate[e] = startingVal; intermediate[e] = startingVal;
memset(extraParamsLocal, 0, 3 * 256 * sizeof(Z)); memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z));
if (extraParams != nullptr) if (extraParams != nullptr) {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int e = 0; e < maxThreads; e++) // mostly for future reference
extraParamsLocal[3 * e + 2] = extraParams[0]; for (int e = 0; e < maxThreads; e++) {
extraParamsLocal[3 * e] = extraParams[0];
extraParamsLocal[3 * e + 1] = extraParams[1];
extraParamsLocal[3 * e + 2] = extraParams[2];
}
}
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo);

if (kindOfLoop == nd4j::LoopKind::EWS1) {
-    PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
-    for (unsigned int i = 0; i < length; i++) {
-        const auto threadNum = omp_get_thread_num();
-        intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[i], y[i], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment) {
+            intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
+        }
+    };
+
+    maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
-    PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
-    for (unsigned int i = 0; i < length; i++) {
-        const auto threadNum = omp_get_thread_num();
-        auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
-        intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment) {
+            auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
+            intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
+        }
+    };
+
+    maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
} else {
    uint yShapeInfoCast[MAX_RANK];
    const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

-    PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
-    for (unsigned int i = 0; i < length; i++) {
-        const auto threadNum = omp_get_thread_num();
-        auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
-        auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
-        intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
-    }
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment) {
+            auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
+            auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
+            intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
+        }
+    };
+
+    maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
}

// merge step
for (int e = 0; e < maxThreads; e++)
    OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e);
for (int e = 0; e < maxThreads; e++)
    startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals);

-// writing out result
z[0] = OpType::postProcess(startingVal, length, extraParamsVals);
}
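For reference, the pattern that recurs throughout this PR is: build a functor with `PRAGMA_THREADS_FOR` (which exposes `thread_id`, `start`, `stop` and `increment`), hand it to `samediff::Threads::parallel_for(func, 0, length, 1, maxThreads)`, and merge the per-thread partials afterwards. The snippet below is a minimal, self-contained sketch of that contract using plain `std::thread`; the helper name `parallel_for_sketch` and its splitting policy are assumptions inferred from the call sites above, not the actual `execution/Threads.h` implementation.

```cpp
// Illustrative stand-in for the Threads::parallel_for pattern used above.
// Assumed contract: functor(thread_id, start, stop, increment); return value = threads actually used.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

using ThreadsFor = std::function<void(int thread_id, int64_t start, int64_t stop, int64_t increment)>;

static int parallel_for_sketch(const ThreadsFor &func, int64_t start, int64_t stop, int64_t increment, int maxThreads) {
    const int64_t span = stop - start;
    const int numThreads = (int) std::max<int64_t>(1, std::min<int64_t>(maxThreads, span));
    const int64_t chunk = span / numThreads;

    std::vector<std::thread> workers;
    for (int t = 0; t < numThreads; t++) {
        const int64_t s = start + t * chunk;
        const int64_t e = (t == numThreads - 1) ? stop : s + chunk;
        workers.emplace_back(func, t, s, e, increment);   // each worker owns its [s, e) slice
    }
    for (auto &w : workers)
        w.join();

    return numThreads;   // caller learns how many per-thread partials to merge
}

int main() {
    std::vector<float> x(1000, 1.0f);
    std::vector<float> partial(4, 0.0f);

    // mirrors the reduce3 pattern above: accumulate into a per-thread slot, merge afterwards
    auto func = [&](int thread_id, int64_t start, int64_t stop, int64_t increment) {
        for (auto i = start; i < stop; i += increment)
            partial[thread_id] += x[i];
    };

    const int used = parallel_for_sketch(func, 0, (int64_t) x.size(), 1, 4);

    float sum = 0.0f;
    for (int e = 0; e < used; e++)
        sum += partial[e];

    return sum == 1000.0f ? 0 : 1;
}
```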
@ -139,7 +150,7 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength) { int *dimension, int dimensionLength, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X*>(vy);
@ -151,9 +162,9 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
return; return;
} }
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams); nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
#else #else
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams); nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
#endif #endif
} }
@ -165,16 +176,16 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams); nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
#else #else
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams); nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
#endif #endif
} }
@ -188,7 +199,7 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
@ -196,9 +207,9 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
auto extraParams = reinterpret_cast<Z*>(vextraParams); auto extraParams = reinterpret_cast<Z*>(vextraParams);
#ifdef INLINE_LOOPS #ifdef INLINE_LOOPS
nd4j::Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams); nd4j::Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop);
#else #else
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams); nd4j::Reduction3Loops<X,Z>::template innerloopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop);
#endif #endif
} }
@ -209,9 +220,9 @@ void Reduce3<X,Y>::exec( const int opNum,
void *extraParamsVals, void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength) { int *dimension, int dimensionLength, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS);
} }
@ -223,9 +234,9 @@ void Reduce3<X,Y>::exec( const int opNum,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS);
} }
@ -238,9 +249,9 @@ void Reduce3<X,Y>::execAll(const int opNum,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), REDUCE3_OPS); DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS);
} }
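The dispatch wrappers above only thread the new `start`/`stop` pair through `DISPATCH_BY_OPNUM_TT` to the templated `exec<OpType>`. Conceptually that macro family boils down to a switch over the op number that calls the matching template instantiation; the hand-written sketch below illustrates that dispatch shape under that assumption (op names and numbers here are invented for the example, not taken from `REDUCE3_OPS`).

```cpp
// Hand-rolled stand-in for DISPATCH_BY_OPNUM-style dispatch: choose the functor at runtime,
// keep the hot loop templated, and pass only this thread's [start, stop) slice through.
#include <cstdint>
#include <stdexcept>

struct DotProductOp  { static float op(float x, float y) { return x * y; } };
struct SquaredDiffOp { static float op(float x, float y) { return (x - y) * (x - y); } };

template <typename OpType>
float execTyped(const float *x, const float *y, int64_t start, int64_t stop) {
    float acc = 0.0f;
    for (auto i = start; i < stop; i++)
        acc += OpType::op(x[i], y[i]);
    return acc;
}

float exec(int opNum, const float *x, const float *y, int64_t start, int64_t stop) {
    switch (opNum) {                              // the switch the dispatch macro effectively generates
        case 0:  return execTyped<DotProductOp>(x, y, start, stop);
        case 1:  return execTyped<SquaredDiffOp>(x, y, start, stop);
        default: throw std::runtime_error("unknown op number");
    }
}

int main() {
    const float a[4] = {1, 2, 3, 4}, b[4] = {1, 1, 1, 1};
    return exec(0, a, b, 0, 4) == 10.0f ? 0 : 1;  // dot product over the full range
}
```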

View File

@ -22,6 +22,7 @@
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <execution/Threads.h>
#include "../legacy_ops.h" #include "../legacy_ops.h"
using namespace simdOps; using namespace simdOps;
@ -39,7 +40,8 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vscalars, void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
@ -63,29 +65,27 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
return; return;
} }
-   int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
+   int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());

    if (kindOfLoop == nd4j::LoopKind::EWS1) {
-       PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
-       for (unsigned int r = 0; r < numTads; r++) {
+       for (auto r = start; r < stop; r++) {
            auto oZ = z + zTadOffsets[r];
            auto oX = x + xTadOffsets[r];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
-       }
+       };
    }
    else {
-       PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
-       for (unsigned int r = 0; r < numTads; r++) {
+       for (auto r = start; r < stop; r++) {
            auto oZ = z + zTadOffsets[r];
            auto oX = x + xTadOffsets[r];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
-       }
+       };
    }
}
@ -98,9 +98,10 @@ void ScalarTransform<X,Y,Z>::transform(int opNum,
void *scalars, void *scalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -110,9 +111,10 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *z, Nd4jLong zStride, void *z, Nd4jLong zStride,
void *scalar, void *scalar,
void *extraParams, void *extraParams,
const Nd4jLong n, bool allowParallelism) { const uint64_t n,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, allowParallelism), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -121,9 +123,10 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong *xShapeInfo,
void *z, Nd4jLong *zShapeInfo, void *z, Nd4jLong *zShapeInfo,
void *scalar, void *scalar,
void *extraParams, bool allowParallelism) { void *extraParams,
const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, allowParallelism), SCALAR_OPS); DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS);
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -132,7 +135,8 @@ template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo, void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
void *vscalar, void *vscalar,
void *vextraParams, bool allowParallelism) { void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
@ -146,48 +150,30 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) { if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, allowParallelism); transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
} }
else { else {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism) for (auto i = start; i < stop; i++) {
{ auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadNum = omp_get_thread_num(); z[offset] = OpType::op(x[offset], scalar, extraParams);
auto threadOffset = info.getThreadOffset(threadNum); };
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], scalar, extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
}
}
} }
} }
} }
@ -199,44 +185,22 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
void *vz, Nd4jLong zEws, void *vz, Nd4jLong zEws,
void *vscalar, void *vscalar,
void *vextraParams, void *vextraParams,
const Nd4jLong len, bool allowParallelism) { const uint64_t len, const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<Y *>(vscalar)[0]; auto scalar = reinterpret_cast<Y *>(vscalar)[0];
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
if (xEws == 1 && zEws == 1) { if (xEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], scalar, extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism) for (auto i = start; i < stop; i++)
{ z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws * threadOffset;
auto zi = z + zEws * threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
}
} }
} }
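With `OmpLaunchHelper` and the `allowParallelism` flag gone, the strided scalar kernel is now a pure slice worker: it only touches the `[start, stop)` range it is handed, and whoever launches it owns the partitioning. Below is a hedged, self-contained sketch of that contract; `scalarSliceWorker` and the two-thread launcher are illustrative stand-ins, not the actual `ScalarTransform` entry point or the NativeOpExecutioner code that drives it.

```cpp
// Sketch: the kernel body is a slice worker, the launcher decides how [0, len) is split.
#include <cstdint>
#include <thread>
#include <vector>

template <typename T>
void scalarSliceWorker(const T *x, int64_t xEws, T *z, int64_t zEws, T scalar, uint64_t start, uint64_t stop) {
    if (xEws == 1 && zEws == 1) {
        for (auto i = start; i < stop; i++)
            z[i] = x[i] + scalar;                  // stands in for OpType::op(x[i], scalar, extraParams)
    } else {
        for (auto i = start; i < stop; i++)
            z[i * zEws] = x[i * xEws] + scalar;
    }
}

int main() {
    std::vector<float> x(10, 2.0f), z(10, 0.0f);

    // the launcher, not the kernel, splits the work across workers
    std::thread a(scalarSliceWorker<float>, x.data(), 1, z.data(), 1, 3.0f, 0, 5);
    std::thread b(scalarSliceWorker<float>, x.data(), 1, z.data(), 1, 3.0f, 5, 10);
    a.join();
    b.join();

    return z[9] == 5.0f ? 0 : 1;
}
```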

View File

@ -22,6 +22,7 @@
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <execution/Threads.h>
#include "../legacy_ops.h" #include "../legacy_ops.h"
@ -39,7 +40,8 @@ namespace functions {
void *vscalars, void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
@ -64,29 +66,27 @@ namespace functions {
return; return;
} }
int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads()); int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) for (auto r = start; r < stop; r++) {
for (unsigned int r = 0; r < numTads; r++) {
auto oZ = z + zTadOffsets[r]; auto oZ = z + zTadOffsets[r];
auto oX = x + xTadOffsets[r]; auto oX = x + xTadOffsets[r];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], scalars[r], extraParams); oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
} };
} }
else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) for (auto r = start; r < stop; r++) {
for (unsigned int r = 0; r < numTads; r++) {
auto oZ = z + zTadOffsets[r]; auto oZ = z + zTadOffsets[r];
auto oX = x + xTadOffsets[r]; auto oX = x + xTadOffsets[r];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
} };
} }
} }
@ -103,8 +103,8 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets) { Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_BOOL_OPS); DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS);
} }
@ -116,8 +116,9 @@ namespace functions {
Nd4jLong zEws, Nd4jLong zEws,
void *scalar, void *scalar,
void *extraParams, void *extraParams,
const Nd4jLong n) { const uint64_t n,
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_BOOL_OPS); const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS);
} }
template<typename X, typename Y> template<typename X, typename Y>
@ -127,8 +128,9 @@ namespace functions {
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *scalar, void *scalar,
void *extraParams) { void *extraParams,
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_BOOL_OPS); const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS);
} }
template<typename X, typename Z> template<typename X, typename Z>
@ -138,7 +140,8 @@ namespace functions {
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vscalar, void *vscalar,
void *vextraParams) { void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
@ -149,53 +152,33 @@ namespace functions {
auto zEws = shape::elementWiseStride(zShapeInfo); auto zEws = shape::elementWiseStride(zShapeInfo);
auto len = shape::length(xShapeInfo); auto len = shape::length(xShapeInfo);
// nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) { if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len); transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
return; return;
} }
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
nd4j::OmpLaunchHelper info(len);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++) {
{ auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadNum = omp_get_thread_num(); z[offset] = OpType::op(x[offset], scalar, extraParams);
auto threadOffset = info.getThreadOffset(threadNum); };
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], scalar, extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
}
}
} }
} }
@ -208,44 +191,23 @@ namespace functions {
Nd4jLong zEws, Nd4jLong zEws,
void *vscalar, void *vscalar,
void *vextraParams, void *vextraParams,
const Nd4jLong len) { const uint64_t len,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
nd4j::OmpLaunchHelper info(len);
if (xEws == 1 && zEws == 1) { if (xEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], scalar, extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws * threadOffset;
auto zi = z + zEws * threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
}
} }
} }

View File

@ -22,6 +22,7 @@
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <types/types.h> #include <types/types.h>
#include <LoopKind.h> #include <LoopKind.h>
#include <execution/Threads.h>
#include "../legacy_ops.h" #include "../legacy_ops.h"
@ -39,7 +40,8 @@ namespace functions {
void *vscalars, void *vscalars,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
@ -64,29 +66,27 @@ namespace functions {
return; return;
} }
int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads()); int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());
if (kindOfLoop == nd4j::LoopKind::EWS1) { if (kindOfLoop == nd4j::LoopKind::EWS1) {
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) for (auto r = start; r < stop; r++) {
for (unsigned int r = 0; r < numTads; r++) {
auto oZ = z + zTadOffsets[r]; auto oZ = z + zTadOffsets[r];
auto oX = x + xTadOffsets[r]; auto oX = x + xTadOffsets[r];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], scalars[r], extraParams); oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
} };
} }
else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO else {
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) for (auto r = start; r < stop; r++) {
for (unsigned int r = 0; r < numTads; r++) {
auto oZ = z + zTadOffsets[r]; auto oZ = z + zTadOffsets[r];
auto oX = x + xTadOffsets[r]; auto oX = x + xTadOffsets[r];
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
} };
} }
} }
@ -103,8 +103,10 @@ namespace functions {
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
Nd4jLong *xTadOffsets, Nd4jLong *xTadOffsets,
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffsets) { Nd4jLong *zTadOffsets,
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS); const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS);
} }
@ -116,8 +118,9 @@ namespace functions {
Nd4jLong zEws, Nd4jLong zEws,
void *scalar, void *scalar,
void *extraParams, void *extraParams,
const Nd4jLong n) { const uint64_t n,
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS); const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS);
} }
template<typename X> template<typename X>
@ -127,8 +130,9 @@ namespace functions {
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *scalar, void *scalar,
void *extraParams) { void *extraParams,
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS); const uint64_t start, const uint64_t stop) {
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS);
} }
template<typename X> template<typename X>
@ -138,7 +142,8 @@ namespace functions {
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vscalar, void *vscalar,
void *vextraParams) { void *vextraParams,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
@ -149,53 +154,33 @@ namespace functions {
auto zEws = shape::elementWiseStride(zShapeInfo); auto zEws = shape::elementWiseStride(zShapeInfo);
auto len = shape::length(xShapeInfo); auto len = shape::length(xShapeInfo);
// nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) { if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len); transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
return; return;
} }
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
nd4j::OmpLaunchHelper info(len);
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++) {
{ auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadNum = omp_get_thread_num(); z[offset] = OpType::op(x[offset], scalar, extraParams);
auto threadOffset = info.getThreadOffset(threadNum); };
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], scalar, extraParams);
}
}
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_SIMD
{ for (auto i = start; i < stop; i++) {
auto threadNum = omp_get_thread_num(); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
auto threadOffset = info.getThreadOffset(threadNum); auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
};
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
}
}
} }
} }
@ -208,44 +193,23 @@ namespace functions {
Nd4jLong zEws, Nd4jLong zEws,
void *vscalar, void *vscalar,
void *vextraParams, void *vextraParams,
const Nd4jLong len) { const uint64_t len,
const uint64_t start, const uint64_t stop) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto scalar = reinterpret_cast<X *>(vscalar)[0]; auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
nd4j::OmpLaunchHelper info(len);
if (xEws == 1 && zEws == 1) { if (xEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i] = OpType::op(x[i], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto zi = z + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i] = OpType::op(xi[i], scalar, extraParams);
}
} }
else { else {
PRAGMA_OMP_SIMD
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) for (auto i = start; i < stop; i++)
{ z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws * threadOffset;
auto zi = z + zEws * threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD
for (unsigned int i = 0; i < ulen; i++)
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
}
} }
} }

View File

@ -24,6 +24,7 @@
#include <helpers/shape.h> #include <helpers/shape.h>
#include <helpers/TAD.h> #include <helpers/TAD.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -90,8 +91,7 @@ namespace functions {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
const bool canCast = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast); const bool canCast = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
for (Nd4jLong i = 0; i < length; i++) { for (uint64_t i = 0; i < length; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast); auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast);
SummaryStatsData<X> curr; SummaryStatsData<X> curr;
@ -123,7 +123,7 @@ namespace functions {
return; return;
SummaryStatsData<X> comp; SummaryStatsData<X> comp;
comp.initWithValue(x[0]); comp.initWithValue(x[0]);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++) for (uint i = 0; i < resultLength; i++)
z[i] = OpType::getValue(biasCorrected, comp); z[i] = OpType::getValue(biasCorrected, comp);
return; return;
@ -157,35 +157,37 @@ namespace functions {
uint tadShapeShapeInfoCast[MAX_RANK]; uint tadShapeShapeInfoCast[MAX_RANK];
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast); const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
-           PRAGMA_OMP_PARALLEL_FOR
-           for (int r = 0; r < resultLength; r++) {
-               auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
-               auto tx = x + tadOffsetForBlock;
-               SummaryStatsData<X> comp;
-               comp.initWithValue(tx[0]);
-
-               if (tadEWS == 1 && tadOrder == 'c') {
-                   for (int i = 1; i < tadLength; i ++) {
-                       SummaryStatsData <X> indexVal2;
-                       indexVal2.initWithValue(tx[i]);
-                       comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
-                   }
-               }
-               else {
-                   for (int i = 1; i < tadLength; i ++) {
-                       auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast);
-                       SummaryStatsData <X> indexVal2;
-                       indexVal2.initWithValue(tx[xOffset]);
-                       comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
-                   }
-               }
-               z[r] = OpType::getValue(biasCorrected, comp);
-           }
+           auto func = PRAGMA_THREADS_FOR {
+               for (auto r = start; r < stop; r += increment) {
+                   auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
+                   auto tx = x + tadOffsetForBlock;
+                   SummaryStatsData <X> comp;
+                   comp.initWithValue(tx[0]);
+
+                   if (tadEWS == 1 && tadOrder == 'c') {
+                       for (int i = 1; i < tadLength; i++) {
+                           SummaryStatsData <X> indexVal2;
+                           indexVal2.initWithValue(tx[i]);
+                           comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
+                       }
+                   } else {
+                       for (int i = 1; i < tadLength; i++) {
+                           auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast);
+                           SummaryStatsData <X> indexVal2;
+                           indexVal2.initWithValue(tx[xOffset]);
+                           comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
+                       }
+                   }
+
+                   z[r] = OpType::getValue(biasCorrected, comp);
+               }
+           };
+
+           samediff::Threads::parallel_tad(func, 0, resultLength, 1);
        }
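Here the per-TAD loop becomes a `PRAGMA_THREADS_FOR` functor handed to `samediff::Threads::parallel_tad`, so each worker processes whole sub-arrays (one output value per TAD) rather than interleaved elements. The toy below illustrates that shape with plain `std::thread` and a per-row variance; it is a conceptual sketch, not the libnd4j API, and the fixed two-way split merely stands in for whatever `parallel_tad` decides internally.

```cpp
// Toy per-TAD reduction: each "TAD" is one matrix row; threads split rows, not elements.
#include <cmath>
#include <cstdint>
#include <thread>
#include <vector>

int main() {
    const int64_t rows = 8, cols = 16;
    std::vector<double> x(rows * cols, 1.0), z(rows, 0.0);

    auto func = [&](int /*thread_id*/, int64_t start, int64_t stop, int64_t increment) {
        for (auto r = start; r < stop; r += increment) {      // one output value per TAD
            double mean = 0.0, m2 = 0.0;
            for (int64_t i = 0; i < cols; i++)
                mean += x[r * cols + i];
            mean /= cols;
            for (int64_t i = 0; i < cols; i++)
                m2 += (x[r * cols + i] - mean) * (x[r * cols + i] - mean);
            z[r] = m2 / (cols - 1);                            // bias-corrected variance, like getValue(true, comp)
        }
    };

    // stand-in for samediff::Threads::parallel_tad(func, 0, rows, 1): two workers, half the rows each
    std::thread a(func, 0, 0, rows / 2, 1), b(func, 1, rows / 2, rows, 1);
    a.join();
    b.join();

    return std::fabs(z[0]) < 1e-9 ? 0 : 1;   // constant rows have zero variance
}
```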

View File

@ -37,9 +37,8 @@ namespace functions {
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadOffsets, bool allowParallelism) { DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), TRANSFORM_ANY_OPS);
} }
///////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////
@ -47,22 +46,13 @@ template <typename X, typename Z>
template<typename OpType> template<typename OpType>
 void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
                                       void *vz, Nd4jLong *zShapeInfo,
-                                      void *vextraParams,
-                                      Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
+                                      void *vextraParams, uint64_t threadId, uint64_t numThreads) {
     auto x = reinterpret_cast<X *>(vx);
     auto z = reinterpret_cast<Z *>(vz);
     auto extraParams = reinterpret_cast<X *>(vextraParams);

-    if(OpType::requiresSpecial) {
-        OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
-        return;
-    }
-
-    if (allowParallelism)
-        nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
-    else
-        nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, false>(x, xShapeInfo, z, zShapeInfo, extraParams);
+    nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
 }
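From here on the transform kernels no longer parallelise themselves: the `tadShapeInfo`/`tadOffsets`/`allowParallelism` arguments are dropped and each call receives a `(threadId, numThreads)` pair, with `TransformLoops::loopTransform` carving out that thread's share internally. The sketch below shows one plausible self-partitioning scheme (a simple block split) under that assumption; the actual splitting logic inside `TransformLoops` may differ, and `expTransformSketch` is an illustrative name only.

```cpp
// Each worker calls the same kernel with its own threadId; the kernel derives its block.
#include <cstdint>
#include <thread>
#include <vector>

void expTransformSketch(const float *x, float *z, uint64_t length, uint64_t threadId, uint64_t numThreads) {
    // block partitioning: thread t owns [t * span, t * span + span), last thread takes the remainder
    const uint64_t span  = length / numThreads;
    const uint64_t start = threadId * span;
    const uint64_t stop  = (threadId == numThreads - 1) ? length : start + span;

    for (auto i = start; i < stop; i++)
        z[i] = x[i] * 2.0f;                  // stands in for OpType::op(x[i], extraParams)
}

int main() {
    const uint64_t n = 100, workers = 4;
    std::vector<float> x(n, 1.5f), z(n, 0.0f);

    std::vector<std::thread> pool;
    for (uint64_t t = 0; t < workers; t++)
        pool.emplace_back(expTransformSketch, x.data(), z.data(), n, t, workers);
    for (auto &th : pool)
        th.join();

    return z[n - 1] == 3.0f ? 0 : 1;
}
```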

View File

@ -37,9 +37,8 @@ namespace functions {
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadOffsets) { DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_BOOL_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -49,20 +48,13 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vextraParams, void *vextraParams, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if(OpType::requiresSpecial) { nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
return;
}
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);

View File

@ -36,9 +36,8 @@ namespace functions {
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams, void *extraParams,
Nd4jLong *tadShapeInfo, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadOffsets) { DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -48,20 +47,13 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vextraParams, void *vextraParams, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams); auto extraParams = reinterpret_cast<Z *>(vextraParams);
if(OpType::requiresSpecial) { nd4j::TransformLoops<X,Z,Z>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
return;
}
nd4j::TransformLoops<X,Z,Z>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
} }
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);

View File

@ -36,10 +36,8 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams, void *extraParams, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadShapeInfo, DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS);
Nd4jLong *tadOffsets) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_SAME_OPS);
} }
template <typename X> template <typename X>
@ -47,18 +45,14 @@ namespace functions {
void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo, void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
void *vextraParams, void *vextraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { uint64_t threadId, uint64_t numThreads) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if(OpType::requiresSpecial) {
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
return;
}
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams); nd4j::TransformLoops<X,X,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
} }
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES);

View File

@ -36,10 +36,8 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams, void *extraParams, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadShapeInfo, DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS);
Nd4jLong *tadOffsets) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_STRICT_OPS);
} }
template <typename X> template <typename X>
@ -49,20 +47,13 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vextraParams, void *vextraParams, uint64_t threadId, uint64_t numThreads) {
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffsets) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<X *>(vz); auto z = reinterpret_cast<X *>(vz);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
if(OpType::requiresSpecial) { nd4j::TransformLoops<X,X,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
return;
}
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
} }
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);

View File

@ -1,145 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma, created on 27.11.2018
//
#include "../aggregates.h"
namespace functions {
namespace aggregate {
///////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpClass>
__device__ void AggregatedFunction<X>::execCuda(X **arguments, int numArguments,
Nd4jLong **shapeArguments, int numShapeArguments,
int *indexArguments, int numIndexArguments,
int **intArrays, int numIntArrays,
X *realArguments, int numRealArguments) {
OpClass::executeAggregateCuda(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__device__ void AggregatedFunction<X>::execCuda(int opNum,
X **arguments, int numArguments,
Nd4jLong **shapeArguments, int numShapeArguments,
int *indexArguments, int numIndexArguments,
int **intArrays, int numIntArrays,
X *realArguments, int numRealArguments) {
DISPATCH_BY_OPNUM_T(execCuda, PARAMS(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), AGGREGATE_OPS);
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__global__ static void execAggregateKernel(int opNum,
void **varguments, int numArguments,
Nd4jLong **shapeArguments, int numShapeArguments,
int *indexArguments, int numIndexArguments,
int **intArrays, int numIntArrays,
void *vrealArguments, int numRealArguments) {
auto arguments = reinterpret_cast<X**>(varguments);
auto realArguments = reinterpret_cast<X*>(vrealArguments);
functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__host__ void AggregatedFunction<X>::aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream,
int opNum,
void **arguments, int numArguments,
Nd4jLong **shapeArguments, int numShapeArguments,
int *indexArguments, int numIndexArguments,
int **intArrays, int numIntArrays,
void *realArguments, int numRealArguments) {
execAggregateKernel<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
nd4j::DebugHelper::checkErrorCode(stream, "aggregateKernelGeneric(...) failed");
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__device__ void AggregatedFunction<X>::aggregateBatch(int opNum, int numAggregates,
int maxArgs, int maxShapes,
int maxIntArrays, int maxIntArraySize,
int maxIdx, int maxReals,
void *ptrToArguments) {
nd4j::PointersHelper<X> helper(ptrToArguments, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals);
// TODO: we probably should lift this restriction
__shared__ int *intArrays[32];
__shared__ X **arguments;
__shared__ Nd4jLong **shapes;
__shared__ int *idxArg;
__shared__ X *realArg;
for(int r = blockIdx.x; r < numAggregates; r += gridDim.x) {
if (threadIdx.x == 0) {
arguments = helper.getArguments(r);
shapes = helper.getShapeArguments(r);
idxArg = helper.getIndexArguments(r);
realArg = helper.getRealArguments(r);
}
// we fill intArrays param in parallel within block
if (threadIdx.x < 32 && threadIdx.x < maxIntArrays) {
intArrays[threadIdx.x] = helper.getIntArrayArguments(r, threadIdx.x);
}
__syncthreads();
functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, helper.getNumArguments(r), shapes, helper.getNumShapeArguments(r), idxArg, helper.getNumIndexArguments(r), intArrays, helper.getNumIntArrayArguments(r), realArg, helper.getNumRealArguments(r));
}
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__global__ static void execAggregateBatch(int opNum, int numAggregates,
int maxArgs, int maxShapes,
int maxIntArrays, int maxIntArraySize,
int maxIdx, int maxReals,
void *ptrToArguments) {
functions::aggregate::AggregatedFunction<X>::aggregateBatch(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
}
///////////////////////////////////////////////////////////////////////
template <typename X>
__host__ void AggregatedFunction<X>::aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream,
int opNum, int numAggregates,
int maxArgs, int maxShapes,
int maxIntArrays, int maxIntArraySize,
int maxIdx, int maxReals,
void *ptrToArguments) {
execAggregateBatch<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
nd4j::DebugHelper::checkErrorCode(stream, "aggregateBatchKernel(...) failed");
}
BUILD_SINGLE_TEMPLATE(template class AggregatedFunction, , FLOAT_TYPES);
}
}

View File

@ -32,84 +32,6 @@
namespace functions { namespace functions {
namespace broadcast { namespace broadcast {
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
//
}
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
/**
* CPU execution
* @param x the input
* @param xShapeInfo the x shape information
* @param y the y data
* @param yShapeInfo the y shape information
* @param result the result
* @param resultShapeInfo the result shape information
* @param dimension the dimension to broadcast along long
* @param dimensionLength the length of the dimension buffer
*/
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
//
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
} }
} }

View File

@ -224,76 +224,6 @@ namespace functions {
} }
template<typename X, typename Y>
void BroadcastBool<X,Y>::exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
void BroadcastBool<X,Y>::execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
template<typename OpType>
void BroadcastBool<X,Y>::exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X, typename Y>
template<typename OpType>
void BroadcastBool<X,Y>::execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
} }
} }

View File

@ -217,75 +217,6 @@ namespace functions {
} }
} }
template<typename X>
void BroadcastInt<X>::exec(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X>
void BroadcastInt<X>::execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X>
template<typename OpType>
void BroadcastInt<X>::exec(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
template<typename X>
template<typename OpType>
void BroadcastInt<X>::execInverse(void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
} }
} }

View File

@ -359,32 +359,6 @@ namespace functions {
} }
} }
template <typename X, typename Z>
Nd4jLong IndexReduce<X,Z>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename X, typename Z>
void IndexReduce<X,Z>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
template <typename X, typename Z>
template<typename OpType>
Nd4jLong IndexReduce<X,Z>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
return 0;
}
template <typename X, typename Z>
template<typename OpType>
_CUDA_H void IndexReduce<X,Z>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
}
}


@ -22,58 +22,6 @@
namespace functions {
namespace pairwise_transforms {
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong *xShapeInfo,
void *y,
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams) {
}
template <typename X, typename Y, typename Z>
void PairWiseTransform<X, Y, Z>::exec(
const int opNum,
void *x,
Nd4jLong xStride,
void *y,
Nd4jLong yStride,
void *z,
Nd4jLong resultStride,
void *extraParams,
Nd4jLong len) {
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void PairWiseTransform<X, Y, Z>:: exec(
void *vx,
Nd4jLong* xShapeInfo,
void *vy,
Nd4jLong* yShapeInfo,
void *vresult,
Nd4jLong* zShapeInfo,
void *vextraParams) {
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void PairWiseTransform<X, Y, Z>::exec(void *vx,
Nd4jLong xStride,
void *vy,
Nd4jLong yStride,
void *vresult,
Nd4jLong resultStride,
void *vextraParams,
const Nd4jLong len) {
}
}
}


@ -110,63 +110,6 @@ void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
}
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::exec(
const int opNum,
void *dx,
Nd4jLong *xShapeBuffer,
void *y,
Nd4jLong *yShapeBuffer,
void *result,
Nd4jLong *resultShapeBuffer,
void *extraParams) {
}
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::exec(
const int opNum,
void *dx,
Nd4jLong xStride,
void *y,
Nd4jLong yStride,
void *result,
Nd4jLong resultStride,
void *extraParams,
Nd4jLong n) {
}
template<typename X, typename Y>
template<typename OpType>
void PairWiseBoolTransform<X,Y>::exec(
void *vx,
Nd4jLong* xShapeBuffer,
void *vy,
Nd4jLong* yShapeBuffer,
void *vresult,
Nd4jLong* resultShapeBuffer,
void *vextraParams) {
}
template<typename X, typename Y>
template<typename OpType>
void PairWiseBoolTransform<X,Y>::exec(void *vx,
Nd4jLong xStride,
void *vy,
Nd4jLong yStride,
void *vresult,
Nd4jLong resultStride,
void *vextraParams,
const Nd4jLong n) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
}
}


@ -109,63 +109,6 @@ void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);
}
template<typename X>
void PairWiseIntTransform<X>::exec(
const int opNum,
void *dx,
Nd4jLong *xShapeBuffer,
void *y,
Nd4jLong *yShapeBuffer,
void *result,
Nd4jLong *resultShapeBuffer,
void *extraParams) {
}
template<typename X>
void PairWiseIntTransform<X>::exec(
const int opNum,
void *dx,
Nd4jLong xStride,
void *y,
Nd4jLong yStride,
void *result,
Nd4jLong resultStride,
void *extraParams,
Nd4jLong n) {
}
template<typename X>
template<typename OpType>
void PairWiseIntTransform<X>::exec(
void *vx,
Nd4jLong* xShapeBuffer,
void *vy,
Nd4jLong* yShapeBuffer,
void *vresult,
Nd4jLong* resultShapeBuffer,
void *vextraParams) {
}
template<typename X>
template<typename OpType>
void PairWiseIntTransform<X>::exec(void *vx,
Nd4jLong xStride,
void *vy,
Nd4jLong yStride,
void *vresult,
Nd4jLong resultStride,
void *vextraParams,
const Nd4jLong n) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}


@ -442,39 +442,6 @@ namespace functions {
DEBUG_KERNEL(stream, opNum);
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
template<typename OpClass>
void RandomFunction<T>::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
template<typename T>
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}
}


@ -132,7 +132,7 @@ __device__ void Reduce3<X,Z>::execScalarCuda( void *vx, Nd4jLong *xShapeInfo,
extraZ[1] = (Z) 0.0f;
if (extraParams != nullptr)
-    extraZ[2] = *(static_cast<Z*>(extraParams));
+    extraZ[2] = static_cast<Z*>(extraParams)[2];
else
    extraZ[2] = (Z) 0.0f;
}
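
The one-line change in this hunk replaces a dereference of the cast pointer, which always reads element 0 of extraParams, with an explicit [2] index. A minimal standalone sketch of that pointer semantics follows; it assumes a plain float buffer, and the enclosing main(), the literal values, and every name other than extraParams/extraZ are illustrative only, not libnd4j code.

// Illustrative sketch: reading a type-erased extra-parameters buffer in C++.
// Buffer contents and the surrounding program are hypothetical, not from libnd4j.
#include <cstdio>

int main() {
    float params[3] = {0.5f, 1.5f, 2.5f};
    void *extraParams = params;                           // type-erased, as such buffers are passed to kernels

    float extraZ[3] = {0.0f, 0.0f, 0.0f};
    extraZ[2] = *(static_cast<float*>(extraParams));      // dereference: reads params[0], i.e. 0.5f
    extraZ[2] = static_cast<float*>(extraParams)[2];      // index [2]:   reads params[2], i.e. 2.5f

    std::printf("extraZ[2] = %f\n", extraZ[2]);           // prints extraZ[2] = 2.500000
    return 0;
}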


@ -27,56 +27,7 @@
namespace functions {
namespace reduce3 {
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template <typename X, typename Y>
template<typename OpType>
void Reduce3<X,Y>::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
}
template <typename X, typename Y>
void Reduce3<X,Y>::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
}
}
}


@ -231,41 +231,6 @@ void ScalarBoolTransform<X,Y>::executeCudaAlongDimension(dim3& launchDims, cudaS
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
template<typename X, typename Y>
template <typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
template<typename X, typename Y>
template<typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X, typename Y>
template<typename OpType>
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
}
}


@ -230,40 +230,6 @@ void ScalarIntTransform<X>::executeCudaAlongDimension(dim3& launchDims, cudaStre
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);
template<typename X>
template <typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X>
void ScalarIntTransform<X>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
}
}


@ -414,73 +414,6 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
}
template <typename X, typename Y>
Y SummaryStatsReduce<X,Y>::execScalar(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
return 0;
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::execScalar(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer) {
}
template <typename X, typename Y>
void SummaryStatsReduce<X,Y>::exec(int opNum,
bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer,
int *dimension, int dimensionLength) {
}
template <typename X, typename Y>
template<typename OpType>
Y SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams) {
return 0;
}
template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer) {
//
}
template <typename X, typename Y>
template<typename OpType>
void SummaryStatsReduce<X,Y>::exec(bool biasCorrected,
void *x,
Nd4jLong *xShapeInfo,
void *extraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer,
int *dimension,
int dimensionLength) {
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES);
}
}
