diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 2e3c51091..8e940bedb 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -163,9 +163,9 @@ if(CUDA_BLAS) if(CUDA_VERSION VERSION_GREATER "9.2") # cuda 10 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) endif() else() list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) @@ -173,24 +173,24 @@ if(CUDA_BLAS) elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8.0 if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() else() if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 ) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_52,code=sm_52 ) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() endif() @@ -205,34 +205,34 @@ if(CUDA_BLAS) message("CUDA 10 Debug build") if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=compute_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 if ("${COMPUTE}" STREQUAL "all") if (APPLE) - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) elseif() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62 -gencode arch=compute_70,code=sm_70) endif() else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_9 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() elseif (CUDA_VERSION VERSION_GREATER "7.5") # cuda 8 if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_62,code=sm_62) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_8 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() else() if ("${COMPUTE}" STREQUAL "all") - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -gencode arch=compute_30,code=sm_30 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_53,code=sm_53) else() - list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) + list(APPEND CUDA_NVCC_FLAGS -DCUDA_75 ${EXPM} -w -G -g --cudart=static --expt-extended-lambda -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) endif() endif() endif() @@ -249,7 +249,7 @@ if(CUDA_BLAS) file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h) file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h) file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h) - file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) + file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/impl/*.cpp ../include/loops/*.h) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) if (NOT BUILD_TESTS) diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 87555a303..9ce90176f 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1769,6 +1769,17 @@ ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); +typedef nd4j::LaunchContext OpaqueLaunchContext; + +ND4J_EXPORT OpaqueLaunchContext* defaultLaunchContext(); +ND4J_EXPORT Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); +ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); + } #endif //NATIVEOPERATIONS_NATIVEOPS_H diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 74bd072c8..f5d4996e4 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2985,6 +2985,38 @@ const char* runFullBenchmarkSuit(bool printOut) { return chars; } +nd4j::LaunchContext* defaultLaunchContext() { + return LaunchContext::defaultContext(); +} + +Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { + return nullptr; +} + +Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { + return nullptr; +} + BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index 67173c971..126837ad9 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -356,7 +356,7 @@ void NDArray::tile(const std::vector& reps, NDArray& target) const { auto stream = getContext()->getCudaStream(); prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); registerSpecialUse({&target}, {this}); } @@ -375,7 +375,7 @@ void NDArray::tile(NDArray& target) const { auto stream = getContext()->getCudaStream(); prepareSpecialUse({&target}, {this}); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), tileKernelHH, (getSpecialBuffer(), getSpecialShapeInfo(), target.getSpecialBuffer(), target.getSpecialShapeInfo(), targetLen, ews, stream), LIBND4J_TYPES); registerSpecialUse({&target}, {this}); } @@ -434,7 +434,7 @@ void NDArray::repeat(int dimension, NDArray& target) const { NDArray::prepareSpecialUse({&target}, {this}); auto stream = getContext()->getCudaStream(); - BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES); NDArray::registerSpecialUse({&target}, {this}); } diff --git a/libnd4j/blas/cuda/NDArrayLambda.hpp b/libnd4j/blas/cuda/NDArrayLambda.hpp index f7846e121..bf9848981 100644 --- a/libnd4j/blas/cuda/NDArrayLambda.hpp +++ b/libnd4j/blas/cuda/NDArrayLambda.hpp @@ -23,6 +23,14 @@ #include #include +static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); +} + +static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); +} + template static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); template static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, Lambda lambda); @@ -86,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -95,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL z[e * zEws] = lambda(x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(x[xOffset]); } @@ -115,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -124,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz z[e * zEws] = lambda(e, x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(e, x[xOffset]); } @@ -147,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -156,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(e, x[xOffset], y[yOffset]); } @@ -180,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -189,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(x[xOffset], y[yOffset]); } @@ -216,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = shape::length(zShapeInfo); + auto zLength = __length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -225,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto wOffset = shape::getIndexOffset(e, wShapeInfo, zLength); - auto xOffset = shape::getIndexOffset(e, xShapeInfo, zLength); - auto yOffset = shape::getIndexOffset(e, yShapeInfo, zLength); - auto zOffset = shape::getIndexOffset(e, zShapeInfo, zLength); + auto wOffset = __getIndexOffset(e, wShapeInfo, zLength); + auto xOffset = __getIndexOffset(e, xShapeInfo, zLength); + auto yOffset = __getIndexOffset(e, yShapeInfo, zLength); + auto zOffset = __getIndexOffset(e, zShapeInfo, zLength); z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); } diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 0d441fe5e..af9fc6776 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -1691,11 +1692,7 @@ void setOmpMinThreads(int threads) { } int getDevice() { - int curDevice = -1; - - cudaGetDevice(&curDevice); - - return curDevice; + return nd4j::AffinityManager::currentDeviceId(); } void setElementThreshold(int num) { @@ -2391,8 +2388,8 @@ void sortByValue(Nd4jPointer *extraPointers, auto xLength = shape::length(xShapeInfo); auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); // check if xLength is a power of 2, and use bitonic sort, if that's the case @@ -2406,7 +2403,7 @@ void sortByValue(Nd4jPointer *extraPointers, for (int k = 2; k <= xLength; k = 2*k) { for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); } } } else { @@ -2430,7 +2427,7 @@ void sortByValue(Nd4jPointer *extraPointers, int rev = 0; do{ int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); n>>=1; rev = 1; } while(n > 1); @@ -3342,6 +3339,7 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) { nd4j::graph::Context* createGraphContext(int nodeId) { return new nd4j::graph::Context(nodeId); } + nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { return &ptr->randomGenerator(); } @@ -3460,3 +3458,35 @@ const char* runFullBenchmarkSuit(bool printOut) { Nd4jLong getCachedMemory(int deviceId) { return nd4j::ConstantHelper::getInstance()->getCachedAmount(deviceId); } + +nd4j::LaunchContext* defaultLaunchContext() { + return LaunchContext::defaultContext(); +} + +Nd4jPointer lcScalarPointer(OpaqueLaunchContext* lc) { + return lc->getScalarPointer(); +} + +Nd4jPointer lcReductionPointer(OpaqueLaunchContext* lc) { + return lc->getReductionPointer(); +} + +Nd4jPointer lcAllocationPointer(OpaqueLaunchContext* lc) { + return lc->getAllocationPointer(); +} + +Nd4jPointer lcExecutionStream(OpaqueLaunchContext* lc) { + return lc->getCudaStream(); +} + +Nd4jPointer lcCopyStream(OpaqueLaunchContext* lc) { + return lc->getCudaSpecialStream(); +} + +Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { + return lc->getCublasHandle(); +} + +Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { + return lc->getCusolverHandle(); +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java b/libnd4j/include/execution/AffinityManager.h similarity index 50% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java rename to libnd4j/include/execution/AffinityManager.h index 263316d02..463d6942e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/ContextDeallocator.java +++ b/libnd4j/include/execution/AffinityManager.h @@ -14,29 +14,33 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.garbage; +// +// @author raver119@gmail.com +// -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.linalg.api.memory.Deallocator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.jcublas.context.CudaContext; +#ifndef LIBND4J_AFFINITYMANAGER_H +#define LIBND4J_AFFINITYMANAGER_H -/** - * This class provides Deallocator implementation for tracking/releasing CudaContexts once thread holding it dies - * @author raver119@gmail.com - */ -@Slf4j -public class ContextDeallocator implements Deallocator { - private CudaContext context; +#include +#include +#include +#include - public ContextDeallocator(@NonNull CudaContext context) { - this.context = context; - } +namespace nd4j { + class ND4J_EXPORT AffinityManager { + private: + static std::atomic _lastDevice; + static int _numberOfDevices; + static std::mutex _currentMutex; + static std::mutex _numberMutex; - @Override - public void deallocate() { - AtomicAllocator.getInstance().getContextPool().releaseContext(context); - } + public: + static int currentNativeDeviceId(); + static int currentDeviceId(); + static int numberOfDevices(); + static void setCurrentDevice(int deviceId); + static void setCurrentNativeDevice(int deviceId); + }; } + +#endif //DEV_TESTS_AFFINITYMANAGER_H diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h new file mode 100644 index 000000000..77b8d4ca3 --- /dev/null +++ b/libnd4j/include/execution/ContextBuffers.h @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_CONTEXTBUFFERS_H +#define LIBND4J_CONTEXTBUFFERS_H + +#include +#include + +namespace nd4j { + class ND4J_EXPORT ContextBuffers { + private: + void* _reductionPointer; + void* _scalarPointer; + void* _allocationPointer; + bool _allocated = true; + + int _deviceId = -1; + + void initialize(); + public: + ContextBuffers(); + ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner = false); + ~ContextBuffers(); + + void* reductionBuffer(); + void* scalarBuffer(); + void* allocationBuffer(); + + void setReductionBuffer(void* pointer); + void setScalarBuffer(void* pointer); + void setAllocationBuffer(void* pointer); + + void triggerOwnership(bool isOwner); + + int deviceId(); + }; +} + + +#endif //DEV_TESTS_CONTEXTBUFFERS_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 853a970d2..02b772415 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -35,6 +35,8 @@ #include #include #include +#include +#include @@ -44,49 +46,44 @@ class ND4J_EXPORT LaunchContext { private: static std::vector> _contexts; + static std::mutex _mutex; #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ - void* _reductionPointer; - void* _scalarPointer; - int* _allocationPointer; - cudaStream_t *_cudaStream = nullptr; - cudaStream_t *_cudaSpecialStream = nullptr; - void *_cublasHandle = nullptr; + cudaStream_t* _cudaStream = nullptr; + cudaStream_t* _cudaSpecialStream = nullptr; + void* _cublasHandle = nullptr; + void* _cusolverHandle = nullptr; #endif // JCPP bool _isAllocated = false; #endif // CUDA - nd4j::memory::Workspace* _workspace = nullptr; - int _deviceID = 0; + nd4j::memory::Workspace* _workspace = nullptr; + int _deviceID = 0; + public: #ifdef __CUDABLAS__ #ifndef __JAVACPP_HACK__ LaunchContext(cudaStream_t* cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer = nullptr, void* scalarPointer = nullptr, int* allocationPointer = nullptr); - FORCEINLINE void* getReductionPointer () const {return _reductionPointer;}; + void* getReductionPointer () const; + void* getScalarPointer() const; + int* getAllocationPointer() const; + void* getCublasHandle() const; + void* getCusolverHandle() const; + cudaStream_t* getCudaStream() const; + cudaStream_t* getCudaSpecialStream() const; - FORCEINLINE void* getScalarPointer() const {return _scalarPointer;}; - - FORCEINLINE int* getAllocationPointer() const {return _allocationPointer;}; - - FORCEINLINE void* getCublasHandle() const {return _cublasHandle;}; - FORCEINLINE cudaStream_t* getCudaStream() const {return _cudaStream;}; - FORCEINLINE cudaStream_t* getCudaSpecialStream() const {return _cudaSpecialStream;}; - - FORCEINLINE void setReductionPointer (void* reductionPointer) {_reductionPointer = reductionPointer;}; - - FORCEINLINE void setScalarPointer(void* scalarPointer) {_scalarPointer = scalarPointer;}; - - FORCEINLINE void setAllocationPointer(int* allocationPointer) {_allocationPointer = allocationPointer;}; - - FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;}; - FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;}; - FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; }; + void setReductionPointer (void* reductionPointer); + void setScalarPointer(void* scalarPointer); + void setAllocationPointer(int* allocationPointer); + void setCudaStream(cudaStream_t* cudaStream); + void setCudaSpecialStream(cudaStream_t* cudaStream); + void setCublasHandle(void *handle); #endif // JCPP diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java b/libnd4j/include/execution/cpu/AffinityManager.cpp similarity index 54% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java rename to libnd4j/include/execution/cpu/AffinityManager.cpp index 5045c8870..7927982a6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPool.java +++ b/libnd4j/include/execution/cpu/AffinityManager.cpp @@ -14,28 +14,30 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.context; +// +// @author raver119@gmail.com +// -import org.nd4j.linalg.jcublas.context.CudaContext; +#include -/** - * This interface describes pool of CudaContext objects, used to execute kernels - * @author raver119@gmail.com - */ -public interface ContextPool { - /** - * This method returns CudaContext for given device - * @param deviceId - * @return - */ - CudaContext acquireContextForDevice(Integer deviceId); +namespace nd4j { + int AffinityManager::currentDeviceId() { + return 0; + } - @Deprecated - ContextPack acquireContextPackForDevice(Integer deviceId); + int AffinityManager::currentNativeDeviceId() { + return 0; + } - /** - * This method returns CudaContext to the pool for reuse - * @param context - */ - void releaseContext(CudaContext context); -} + int AffinityManager::numberOfDevices() { + return 1; + } + + void AffinityManager::setCurrentDevice(int deviceId) { + // no-op + } + + void AffinityManager::setCurrentNativeDevice(int deviceId) { + // no-op + } +} \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp new file mode 100644 index 000000000..d385548d0 --- /dev/null +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#include +#include + +namespace nd4j { + ContextBuffers::ContextBuffers() { + _deviceId = AffinityManager::currentDeviceId(); + } + + ContextBuffers::~ContextBuffers() { + // no-op + } + + ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; + } + + void ContextBuffers::initialize() { + // no-op + } + + void* ContextBuffers::reductionBuffer() { + return _reductionPointer; + } + + void* ContextBuffers::scalarBuffer() { + return _scalarPointer; + } + + void* ContextBuffers::allocationBuffer() { + return _allocationPointer; + } + + void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; + } + + void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; + } + + void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; + } + + void ContextBuffers::triggerOwnership(bool isOwner) { + _allocated = isOwner; + } + + int ContextBuffers::deviceId() { + return _deviceId; + } +} diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp new file mode 100644 index 000000000..47207719f --- /dev/null +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.11.17. +// + +#include +#include +#include +#include + +thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); + +namespace nd4j { + + LaunchContext::~LaunchContext() { + + } + + std::vector> LaunchContext::_contexts = std::vector>(); + +//////////////////////////////////////////////////////////////////////// + LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; + } + + LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { + + } + + LaunchContext* LaunchContext::defaultContext() { + // TODO: we need it to be device-aware, but only once we add NUMA support for cpu + if (LaunchContext::_contexts.empty()) { + LaunchContext::_contexts.emplace_back(std::make_shared()); + } + + // return context for current device + return LaunchContext::_contexts[0].get(); + } +} \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/AffinityManager.cu b/libnd4j/include/execution/cuda/AffinityManager.cu new file mode 100644 index 000000000..811dc267a --- /dev/null +++ b/libnd4j/include/execution/cuda/AffinityManager.cu @@ -0,0 +1,108 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +thread_local int globalThreadToDevice = -1; + +namespace nd4j { + std::mutex AffinityManager::_currentMutex; + std::mutex AffinityManager::_numberMutex; + int AffinityManager::_numberOfDevices = -1; + + int AffinityManager::currentDeviceId() { + // if there's no affinity set - set it now + if (globalThreadToDevice < 0) { + + // this block must be thread-local + _currentMutex.lock(); + + globalThreadToDevice = _lastDevice++; + + // we need to check if we've got deviceId >= number of actual devices, and reset to zero otherwise + if (globalThreadToDevice >= numberOfDevices()) { + globalThreadToDevice = 0; + _lastDevice = numberOfDevices() > 1 ? 1 : 0; + } + + _currentMutex.unlock(); + + setCurrentDevice(globalThreadToDevice); + } + + // if we already know affinity - just return it + if (globalThreadToDevice >= 0) + return globalThreadToDevice; + + int dev = 0; + auto res = cudaGetDevice(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDevice failed", res); + + return dev; + } + + int AffinityManager::currentNativeDeviceId() { + int dev = 0; + auto res = cudaGetDevice(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDevice failed", res); + + return dev; + } + + int AffinityManager::numberOfDevices() { + _numberMutex.lock(); + // we want to cache number of devices + if (_numberOfDevices <= 0) { + int dev = 0; + auto res = cudaGetDeviceCount(&dev); + + if (res != 0) + throw cuda_exception::build("cudaGetDeviceCount failed", res); + + _numberOfDevices = dev; + } + _numberMutex.unlock(); + + return _numberOfDevices; + } + + void AffinityManager::setCurrentNativeDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + } + + void AffinityManager::setCurrentDevice(int deviceId) { + auto res = cudaSetDevice(deviceId); + if (res != 0) + throw cuda_exception::build("cudaSetDevice failed", res); + + // update thread-device affinity + globalThreadToDevice = deviceId; + + // TODO: update context buffers? + } + + std::atomic AffinityManager::_lastDevice;// = std::atomic(initialV); +} \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu new file mode 100644 index 000000000..f82747b91 --- /dev/null +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -0,0 +1,116 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +#include +#include +#include +#include + +namespace nd4j { + ContextBuffers::ContextBuffers() { + nd4j_printf("Creating ContextBuffers for device [%i]\n", AffinityManager::currentDeviceId()); + _deviceId = AffinityManager::currentDeviceId(); + } + + ContextBuffers::~ContextBuffers() { + if (_allocated) { + nd4j_printf("Releasing ContextBuffers\n",""); + + if (_allocationPointer != nullptr) + cudaFree(_allocationPointer); + + if (_scalarPointer != nullptr) + cudaFree(_scalarPointer); + + if (_allocationPointer != nullptr) + cudaFree(_reductionPointer); + } + } + + ContextBuffers::ContextBuffers(void* rPointer, void* sPointer, void* aPointer, bool isOwner) { + _reductionPointer = rPointer; + _scalarPointer = sPointer; + _allocationPointer = aPointer; + _allocated = isOwner; + } + + void ContextBuffers::initialize() { + nd4j_printf("Initializing buffers on deviceId [%i]\n", AffinityManager::currentNativeDeviceId()); + + auto res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); + if (res != 0) + throw std::runtime_error("_reductionPointer allocation failed"); + + res = cudaMalloc(reinterpret_cast(&_scalarPointer), 16); + if (res != 0) + throw std::runtime_error("_scalarPointer allocation failed"); + + res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); + if (res != 0) + throw std::runtime_error("_allocationPointer allocation failed"); + + _allocated = true; + } + + void* ContextBuffers::reductionBuffer() { + if (_reductionPointer == nullptr) + initialize(); + + return _reductionPointer; + } + + void* ContextBuffers::scalarBuffer() { + if (_scalarPointer == nullptr) + initialize(); + + return _scalarPointer; + } + + void* ContextBuffers::allocationBuffer() { + if (_allocationPointer == nullptr) + initialize(); + + return _allocationPointer; + } + + void ContextBuffers::setReductionBuffer(void* pointer) { + _reductionPointer = pointer; + } + + void ContextBuffers::setScalarBuffer(void* pointer) { + _scalarPointer = pointer; + } + + void ContextBuffers::setAllocationBuffer(void* pointer) { + _allocationPointer = pointer; + } + + void ContextBuffers::triggerOwnership(bool isOwner) { + _allocated = isOwner; + } + + int ContextBuffers::deviceId() { + return _deviceId; + } +} diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu new file mode 100644 index 000000000..004ed2cac --- /dev/null +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -0,0 +1,182 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.11.17. +// + +#include +#include +#include +#include +#include +#include + +thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); + +namespace nd4j { + + std::vector> LaunchContext::_contexts = std::vector>(); + std::mutex LaunchContext::_mutex; + +//////////////////////////////////////////////////////////////////////// +LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { + + _cudaStream = cudaStream; + _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = allocationPointer; + _workspace = nullptr; + _isAllocated = false; +} + +LaunchContext::~LaunchContext() { + if (_isAllocated) { + cudaStreamSynchronize(*_cudaStream); + cudaStreamSynchronize(*_cudaSpecialStream); + + cudaStreamDestroy(*_cudaStream); + cudaStreamDestroy(*_cudaSpecialStream); + + delete _cudaStream; + delete _cudaSpecialStream; + } +} + +//////////////////////////////////////////////////////////////////////// +LaunchContext::LaunchContext() { + // default constructor, just to make clang/ranlib happy + _workspace = nullptr; + _deviceID = 0; + + _isAllocated = true; + _cudaStream = new cudaStream_t(); + _cudaSpecialStream = new cudaStream_t(); + if (nullptr == _cudaStream || nullptr == _cudaSpecialStream) + throw std::runtime_error("Failed to allocate memory for new CUDA stream"); + + cudaError_t err = cudaStreamCreate(_cudaStream); + if (err != 0) + throw cuda_exception::build("Failed to create default CUDA stream with launch context", err); + + err = cudaStreamCreate(_cudaSpecialStream); + if (err != 0) + throw cuda_exception::build("Failed to create special CUDA stream with launch context", err); + + _cublasHandle = CublasHelper::getInstance()->handle(); + + _cusolverHandle = CublasHelper::getInstance()->solver(); + + auto res = cudaStreamSynchronize(*_cudaStream); + if (res != 0) + throw cuda_exception::build("Initial sync failed", res); +} + + LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { + _isAllocated = false; + _cudaStream = reinterpret_cast(cudaStream); + _cudaSpecialStream = reinterpret_cast(cudaStream); + //_reductionPointer = reductionPointer; + //_scalarPointer = scalarPointer; + //_allocationPointer = reinterpret_cast(allocationPointer); + } + + LaunchContext* LaunchContext::defaultContext() { + /** + * This method returns LaunchContext, that has multiple entities within: + * 1) temporary buffers. they must be per-thread + * 2) CUDA stream. it must be either per-thread or per-device + * 3) cuBLAS handle. it must be per-device + */ + auto deviceId = AffinityManager::currentDeviceId(); + + // we need this block synchronous, to avoid double initialization etc + _mutex.lock(); + if (LaunchContext::_contexts.empty()) { + // create one context per device + auto numDevices = AffinityManager::numberOfDevices(); + + _contexts.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentDevice(e); + + LaunchContext::_contexts[e] = std::make_shared(); + } + + // don't forget to restore device back again + AffinityManager::setCurrentDevice(deviceId); + } + _mutex.unlock(); + + // return context for current device + return LaunchContext::_contexts[deviceId].get(); + } + + + void* LaunchContext::getReductionPointer () const { + return contextBuffers.reductionBuffer(); + }; + + void* LaunchContext::getScalarPointer() const { + return contextBuffers.scalarBuffer(); + }; + + int* LaunchContext::getAllocationPointer() const { + return reinterpret_cast(contextBuffers.allocationBuffer()); + }; + + void* LaunchContext::getCublasHandle() const { + return _cublasHandle; + }; + + void* LaunchContext::getCusolverHandle() const { + return _cusolverHandle; + }; + + cudaStream_t* LaunchContext::getCudaStream() const { + return _cudaStream; + }; + + cudaStream_t* LaunchContext::getCudaSpecialStream() const { + return _cudaSpecialStream; + }; + + + void LaunchContext::setReductionPointer (void* reductionPointer) { + contextBuffers.setReductionBuffer(reductionPointer); + }; + + void LaunchContext::setScalarPointer(void* scalarPointer) { + contextBuffers.setScalarBuffer(scalarPointer); + }; + + void LaunchContext::setAllocationPointer(int* allocationPointer) { + contextBuffers.setAllocationBuffer(allocationPointer); + }; + + void LaunchContext::setCudaStream(cudaStream_t* cudaStream) { + _cudaStream = cudaStream; + }; + + void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) { + _cudaSpecialStream = cudaStream; + }; + + void LaunchContext::setCublasHandle(void *handle) { + _cublasHandle = handle; + }; +} \ No newline at end of file diff --git a/libnd4j/include/execution/impl/LaunchContext.cpp b/libnd4j/include/execution/impl/LaunchContext.cpp deleted file mode 100644 index edc95dabc..000000000 --- a/libnd4j/include/execution/impl/LaunchContext.cpp +++ /dev/null @@ -1,130 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 30.11.17. -// - -#include -#include -#include -#include - -namespace nd4j { - -#ifdef __CUDABLAS__ - -//////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) { - - _cudaStream = cudaStream; - _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream; - _reductionPointer = reductionPointer; - _scalarPointer = scalarPointer; - _allocationPointer = allocationPointer; - _workspace = nullptr; - _isAllocated = false; -} -#endif - -LaunchContext::~LaunchContext() { -#ifdef __CUDABLAS__ - if (_isAllocated) { - cudaStreamSynchronize(*_cudaStream); - cudaStreamSynchronize(*_cudaSpecialStream); - - cudaStreamDestroy(*_cudaStream); - cudaStreamDestroy(*_cudaSpecialStream); - - delete _cudaStream; - delete _cudaSpecialStream; - - cudaFree(_reductionPointer); - cudaFree(_allocationPointer); - cudaFree(_scalarPointer); - - cublas::destroyHandle(_cublasHandle); - } -#endif -} - - std::vector> LaunchContext::_contexts = std::vector>(); - -//////////////////////////////////////////////////////////////////////// -LaunchContext::LaunchContext() { - // default constructor, just to make clang/ranlib happy - _workspace = nullptr; - _deviceID = 0; - -#ifdef __CUDABLAS__ - _isAllocated = true; - _cudaStream = new cudaStream_t(); - _cudaSpecialStream = new cudaStream_t(); - if (nullptr == _cudaStream || nullptr == _cudaSpecialStream) - throw std::runtime_error("Failed to allocate memory for new CUDA stream"); - - cudaError_t err = cudaStreamCreate(_cudaStream); - if (err != 0) - throw cuda_exception::build("Failed to create default CUDA stream with launch context", err); - - err = cudaStreamCreate(_cudaSpecialStream); - if (err != 0) - throw cuda_exception::build("Failed to create special CUDA stream with launch context", err); - - _cublasHandle = cublas::handle(); - - auto res = cudaStreamSynchronize(*_cudaStream); - if (res != 0) - throw cuda_exception::build("Initial sync failed", res); - - res = cudaMalloc(reinterpret_cast(&_reductionPointer), 1024 * 1024 * 8); - if (res != 0) - throw std::runtime_error("_reductionPointer allocation failed"); - - res = cudaMalloc(reinterpret_cast(&_scalarPointer), 8); - if (res != 0) - throw std::runtime_error("_scalarPointer allocation failed"); - - res = cudaMalloc(reinterpret_cast(&_allocationPointer), 1024 * 1024 * 8); - if (res != 0) - throw std::runtime_error("_allocationPointer allocation failed"); -#else - // -#endif -} - - LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) { -#ifdef __CUDABLAS__ - _isAllocated = false; - _cudaStream = reinterpret_cast(cudaStream); - _cudaSpecialStream = reinterpret_cast(cudaStream); - _reductionPointer = reductionPointer; - _scalarPointer = scalarPointer; - _allocationPointer = reinterpret_cast(allocationPointer); -#else - // no-op -#endif - } - -LaunchContext* LaunchContext::defaultContext() { - // TODO: we need it to be device-aware - if (LaunchContext::_contexts.empty()) { - LaunchContext::_contexts.emplace_back(std::make_shared()); - } - return LaunchContext::_contexts[0].get(); -} - -} \ No newline at end of file diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index f74bd5637..43a4f97c1 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -21,6 +21,7 @@ #ifndef __CUDABLAS__ #include +#include #include #include #include @@ -59,11 +60,11 @@ namespace nd4j { } int ConstantHelper::getCurrentDevice() { - return 0L; + return AffinityManager::currentDeviceId(); } int ConstantHelper::getNumberOfDevices() { - return 1; + return AffinityManager::numberOfDevices(); } ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) { diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 293360a25..d17d2c021 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -21,6 +21,7 @@ #include "../MmulHelper.h" #include #include +#include namespace nd4j { @@ -147,7 +148,12 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, ////////////////////////////////////////////////////////////////////////////// // MXK x KxN = MxN -NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + if (A->dataType() != B->dataType()) + throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), B->dataType()); + + if (C != nullptr && A->dataType() != C->dataType()) + throw datatype_exception::build("mmulMxM expects all data types to be the same", A->dataType(), C->dataType()); if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); @@ -212,7 +218,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (double) beta, reinterpret_cast(pC->getBuffer()), ldc); } else { - BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } if(pC != C) { @@ -230,6 +237,11 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con //////////////////////////////////////////////////////////////////////////// // MXN x N = M NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + if (X->dataType() != A->dataType()) + throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); + + if (Y != nullptr && X->dataType() != Y->dataType()) + throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType()); int xLenDim, yLenDim(0); @@ -279,7 +291,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); } else { - BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } if(pA != A) @@ -291,6 +304,11 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* //////////////////////////////////////////////////////////////////////////// // (X * Y) = Z[0] NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { + if (X->dataType() != Y->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Y->dataType()); + + if (Z != nullptr && X->dataType() != Z->dataType()) + throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType()); int xLenDim(0), yLenDim(0); @@ -316,13 +334,14 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c const auto yType = Y->dataType(); const auto zType = Z->dataType(); - BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES); + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); return Z; } -BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } diff --git a/libnd4j/include/helpers/cpu/cublasHelper.cpp b/libnd4j/include/helpers/cpu/cublasHelper.cpp index cc2a4029a..3dba2d31e 100644 --- a/libnd4j/include/helpers/cpu/cublasHelper.cpp +++ b/libnd4j/include/helpers/cpu/cublasHelper.cpp @@ -21,13 +21,41 @@ #include "../cublasHelper.h" namespace nd4j { - namespace cublas { - void* handle() { - return nullptr; - } - - void destroyHandle(void* handle) { - // - } + static void* handle_() { + return nullptr; } + + static void destroyHandle_(void* handle) { + + } + + CublasHelper::CublasHelper() { + + } + + CublasHelper::~CublasHelper() { + + } + + CublasHelper* CublasHelper::getInstance() { + if (!_INSTANCE) + _INSTANCE = new nd4j::CublasHelper(); + + return _INSTANCE; + } + + void* CublasHelper::handle() { + return nullptr; + } + + void* CublasHelper::solver() { + return nullptr; + } + + void* CublasHelper::handle(int deviceId) { + return nullptr; + } + + + nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0; } \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/IndexReductionLoops.cpp rename to libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_0.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_0.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_1.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_1.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_2.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_2.cpp diff --git a/libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp b/libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/Reduction3Loops_3.cpp rename to libnd4j/include/helpers/cpu/loops/Reduction3Loops_3.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops.hpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops.hpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops.hpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_bool.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_0.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_0.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_1.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_1.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_2.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_2.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_float_3.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_float_3.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_long.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp diff --git a/libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp similarity index 100% rename from libnd4j/include/helpers/impl/loops/ReductionLoops_same.cpp rename to libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp diff --git a/libnd4j/include/helpers/cublasHelper.h b/libnd4j/include/helpers/cublasHelper.h index bff16b2d4..d4f92881e 100644 --- a/libnd4j/include/helpers/cublasHelper.h +++ b/libnd4j/include/helpers/cublasHelper.h @@ -21,12 +21,28 @@ #ifndef DEV_TESTS_CUBLASHELPER_H #define DEV_TESTS_CUBLASHELPER_H -namespace nd4j { - namespace cublas { - void* handle(); +#include +#include +#include - void destroyHandle(void* handle); - } +namespace nd4j { + class CublasHelper { + private: + static CublasHelper *_INSTANCE; + + std::vector _cache; + std::vector _solvers; + + CublasHelper(); + ~CublasHelper(); + public: + static CublasHelper* getInstance(); + + void* solver(); + + void* handle(); + void* handle(int deviceId); + }; } #endif //DEV_TESTS_CUBLASHELPER_H diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index d0579b66d..0c7f2cbc1 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -26,6 +26,7 @@ #include #include #include +#include #define CONSTANT_LIMIT 49152 @@ -43,23 +44,11 @@ namespace nd4j { } int ConstantHelper::getCurrentDevice() { - int dev = 0; - auto res = cudaGetDevice(&dev); - - if (res != 0) - throw cuda_exception::build("cudaGetDevice failed", res); - - return dev; + return AffinityManager::currentDeviceId(); } int ConstantHelper::getNumberOfDevices() { - int dev = 0; - auto res = cudaGetDeviceCount(&dev); - - if (res != 0) - throw cuda_exception::build("cudaGetDeviceCount failed", res); - - return dev; + return AffinityManager::numberOfDevices(); } diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 56e726004..ac5eb4176 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou blocksPerGrid.y = math::nd4j_ceil(static_cast(M) / threadsPerBlock.y); // rows } - BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) } if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); @@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* threadsPerBlock.x = 512; blocksPerGrid.x = math::nd4j_ceil(static_cast(M) / threadsPerBlock.x); // rows } - BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) } if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); @@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c NDArray::prepareSpecialUse({Z}, {X, Y}); - BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - // BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) auto cudaResult = cudaStreamSynchronize(*stream); if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); @@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c return Z; } -BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); -BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); +//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda_off/cublasHelper.cu b/libnd4j/include/helpers/cuda_off/cublasHelper.cu index f80bf1f87..6f2cf2084 100644 --- a/libnd4j/include/helpers/cuda_off/cublasHelper.cu +++ b/libnd4j/include/helpers/cuda_off/cublasHelper.cu @@ -20,12 +20,15 @@ #include +#include #include "../cublasHelper.h" #include #include +#include namespace nd4j { - void* cublas::handle() { + + static void* handle_() { auto _handle = new cublasHandle_t(); auto status = cublasCreate_v2(_handle); // initialize CUBLAS context if (status != CUBLAS_STATUS_SUCCESS) @@ -34,7 +37,16 @@ namespace nd4j { return reinterpret_cast(_handle); } - void cublas::destroyHandle(void* handle) { + static void* solver_() { + auto cusolverH = new cusolverDnHandle_t(); + auto status = cusolverDnCreate(cusolverH); + if (status != CUSOLVER_STATUS_SUCCESS) + throw cuda_exception::build("cuSolver handle creation failed !", status); + + return cusolverH; + } + + static void destroyHandle_(void* handle) { auto ch = reinterpret_cast(handle); auto status = cublasDestroy_v2(*ch); if (status != CUBLAS_STATUS_SUCCESS) @@ -42,4 +54,57 @@ namespace nd4j { delete ch; } + + CublasHelper::CublasHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + auto currentDevice = AffinityManager::currentDeviceId(); + _cache.resize(numDevices); + _solvers.resize(numDevices); + for (int e = 0; e < numDevices; e++) { + AffinityManager::setCurrentDevice(e); + + _cache[e] = handle_(); + _solvers[e] = solver_(); + } + + // don't forget to restore back original device + AffinityManager::setCurrentDevice(currentDevice); + } + + CublasHelper::~CublasHelper() { + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < numDevices; e++) + destroyHandle_(_cache[e]); + } + + CublasHelper* CublasHelper::getInstance() { + if (!_INSTANCE) + _INSTANCE = new nd4j::CublasHelper(); + + return _INSTANCE; + } + + void* CublasHelper::handle() { + auto deviceId = AffinityManager::currentDeviceId(); + return handle(deviceId); + } + + void* CublasHelper::solver() { + auto deviceId = AffinityManager::currentDeviceId(); + if (deviceId < 0 || deviceId > _solvers.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); + + return _solvers[deviceId]; + } + + void* CublasHelper::handle(int deviceId) { + if (deviceId < 0 || deviceId > _cache.size()) + throw cuda_exception::build("requested deviceId doesn't look valid", deviceId); + + return _cache[deviceId]; + } + + + nd4j::CublasHelper* nd4j::CublasHelper::_INSTANCE = 0; } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp index e673f4eae..dc8a3eeb1 100644 --- a/libnd4j/include/loops/cuda/broadcasting.chpp +++ b/libnd4j/include/loops/cuda/broadcasting.chpp @@ -60,9 +60,18 @@ static __global__ void broadcastInverseSimple( functions::broadcast::Broadcast::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } + namespace functions { namespace broadcast { + static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); + } + + static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); + } + template template __host__ void Broadcast::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -120,9 +129,9 @@ namespace functions { if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); + tadLength = _length(tadOnlyShapeInfo); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(yShapeInfo) / tadLength; + numTads = _length(yShapeInfo) / tadLength; xEWS = shape::elementWiseStride(xShapeInfo); zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); } @@ -146,9 +155,9 @@ namespace functions { else { // it is expected that x and z tads and y array all have the same length for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength); - auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + auto xOffset = _getIndexOffset(i, xShapeInfo, tadLength); + auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); } } @@ -186,9 +195,9 @@ namespace functions { if (threadIdx.x == 0) { - tadLength = shape::length(tadOnlyShapeInfo); + tadLength = _length(tadOnlyShapeInfo); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); - numTads = shape::length(xShapeInfo) / tadLength; + numTads = _length(xShapeInfo) / tadLength; yEWS = shape::elementWiseStride(yShapeInfo); zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); } @@ -212,14 +221,15 @@ namespace functions { // it is expected that x and z tads and y array all have the same length for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { - auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength); - auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto yOffset = _getIndexOffset(i, yShapeInfo, tadLength); + auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); } } } } + /* BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0); BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1); diff --git a/libnd4j/include/loops/cuda/broadcasting.cu b/libnd4j/include/loops/cuda/broadcasting.cu new file mode 100644 index 000000000..8028db2ba --- /dev/null +++ b/libnd4j/include/loops/cuda/broadcasting.cu @@ -0,0 +1,115 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace functions { + namespace broadcast { + template + void Broadcast::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + // + } + + template + void Broadcast::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + template + void Broadcast::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + // + } + + + template + template + void Broadcast::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index 6acf71356..6cc3f3cbb 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -224,6 +224,77 @@ namespace functions { } } + + template + void BroadcastBool::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + void BroadcastBool::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastBool::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastBool::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 7c17538fa..94793f8e8 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -361,6 +361,32 @@ namespace functions { } } + + + + template + Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + return 0; + } + + template + void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + + } + + template + template + Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { + return 0; + } + + template + template + _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/pairwise.cu b/libnd4j/include/loops/cuda/pairwise.cu new file mode 100644 index 000000000..17f8537e5 --- /dev/null +++ b/libnd4j/include/loops/cuda/pairwise.cu @@ -0,0 +1,79 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../pairwise_transform.h" + +namespace functions { + namespace pairwise_transforms { + template + void PairWiseTransform::exec( + const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *extraParams) { + + } + + template + void PairWiseTransform::exec( + const int opNum, + void *x, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *z, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong len) { + + } + + + template + template + void PairWiseTransform:: exec( + void *vx, + Nd4jLong* xShapeInfo, + void *vy, + Nd4jLong* yShapeInfo, + void *vresult, + Nd4jLong* zShapeInfo, + void *vextraParams) { + + } + + template + template + void PairWiseTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong len) { + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_bool.cu b/libnd4j/include/loops/cuda/pairwise_bool.cu index 41bca38cb..0834386f2 100644 --- a/libnd4j/include/loops/cuda/pairwise_bool.cu +++ b/libnd4j/include/loops/cuda/pairwise_bool.cu @@ -110,6 +110,63 @@ void PairWiseBoolTransform::executeCudaShaped(dim3& launchDims, cudaStream_ DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS); } + + + template + void PairWiseBoolTransform::exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams) { + + } + + template + void PairWiseBoolTransform::exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n) { + + } + + + template + template + void PairWiseBoolTransform::exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams) { + + } + + template + template + void PairWiseBoolTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n) { + + } + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); } diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index 4cc1c6565..727f0868f 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -442,6 +442,39 @@ namespace functions { DEBUG_KERNEL(stream, opNum); } + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + template + void RandomFunction::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + + template + void RandomFunction::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cuda/reduce3.cu b/libnd4j/include/loops/cuda/reduce3.cu new file mode 100644 index 000000000..1ad94beee --- /dev/null +++ b/libnd4j/include/loops/cuda/reduce3.cu @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include +#include +#include +#include +#include + +namespace functions { + namespace reduce3 { + template + template + void Reduce3::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) { + + } + + + template + void Reduce3::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) { + + } + + + template + template + void Reduce3::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { + + } + + + template + template + void Reduce3::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + + template + template + void Reduce3::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { + + } + + + template + void Reduce3::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { + + } + + + template + void Reduce3::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + + template + void Reduce3::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { + + } + + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java b/libnd4j/include/loops/cuda/scalar.cu similarity index 66% rename from nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java rename to libnd4j/include/loops/cuda/scalar.cu index 08ec374b2..67cbc7a98 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ExternalContext.java +++ b/libnd4j/include/loops/cuda/scalar.cu @@ -14,21 +14,19 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.nd4j.jita.allocator.context; +// +// @author raver119@gmail.com +// -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; +#include "loops/scalar.h" +#include +#include +#include +#include +#include -/** - * This is simple class-independant storage for device contexts. - * - * TODO: Something better then typecast required here - * @author raver119@gmail.com - */ -@Data -@NoArgsConstructor -@AllArgsConstructor -public class ExternalContext { - private Object context; -} +namespace functions { + namespace scalar { + + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_bool.cu b/libnd4j/include/loops/cuda/scalar_bool.cu index a5a26d7e7..c6563c9ef 100644 --- a/libnd4j/include/loops/cuda/scalar_bool.cu +++ b/libnd4j/include/loops/cuda/scalar_bool.cu @@ -231,6 +231,41 @@ void ScalarBoolTransform::executeCudaAlongDimension(dim3& launchDims, cudaS } BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES); + + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarBoolTransform::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarBoolTransform::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + template + void ScalarBoolTransform::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + + template + template + void ScalarBoolTransform::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } } } diff --git a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu index 7584949cc..8ee950c25 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicArbitraryStep.cu @@ -21,84 +21,6 @@ #include -////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { - auto x = static_cast(vx); - auto y = static_cast(vy); - - int tid = threadIdx.x + blockDim.x * blockIdx.x; - int half = window>>1; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) { - xLength = shape::length(xShapeInfo); - } - __syncthreads(); - - //for (int i = 0; i < length; i+= window) - /* - if window == 4; - iterations will be: 0; 4; 8; 12; 16; 20 - if gridDim = 3; - on first iteration we'll have: 0; 4; 8; - on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20 - */ - int firstPosition; - int firstStep; - int secondPosition; - int secondStep; - - int WARP_SIZE = 32; - int numWarps = (gridDim.x * blockDim.x) / 32; - int warpId = tid / WARP_SIZE; - int warpIdx = tid % WARP_SIZE; - - if (half >= 128) { - firstPosition = blockIdx.x * window; - firstStep = gridDim.x * window; - - secondPosition = threadIdx.x; - secondStep = blockDim.x; - } else if (half >= 32) { - firstPosition = warpId * window; - firstStep = numWarps * window; - - secondPosition = warpIdx; - secondStep = WARP_SIZE; - } else { - firstPosition = tid * window; - firstStep = blockDim.x * gridDim.x * window; - - secondPosition = 0; - secondStep = 1; - } - - - for (int i = firstPosition; i < length; i += firstStep) { - for (int j = secondPosition; j < half; j += secondStep) { - int it = (reverse) ? i + j + half : i + window - j - 1; - int ij = i+j; - if (it < length && ij < length ) { - int posIT = shape::getIndexOffset(it, yShapeInfo, xLength); - int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength); - - Y v0 = y[posIJ]; - Y v1 = y[posIT]; - - if(!descending == (v0 > v1)) { - y[posIJ] = v1; - y[posIT] = v0; - - X xtemp = x[posIJ]; - x[posIJ] = x[posIT]; - x[posIT] = xtemp; - } - } - } - } -} - ////////////////////////////////////////////////////////////////////////// template __global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { @@ -264,11 +186,5 @@ __host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *str bitonicArbitraryStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); } -template -__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) { - bitonicArbitraryStepKernelValue<<>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending); -} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu index 3e1a0edc5..d9b2ec74c 100644 --- a/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu +++ b/libnd4j/include/loops/cuda/specials/bitonicSortStep.cu @@ -21,60 +21,6 @@ #include -////////////////////////////////////////////////////////////////////////// -template -__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) { - - auto x = static_cast(vx); - auto y = static_cast(vy); - - unsigned int i, ixj; /* Sorting partners: i and ixj */ - i = threadIdx.x + blockDim.x * blockIdx.x; - - __shared__ Nd4jLong xLength; - if (threadIdx.x == 0) - xLength = shape::length(xShapeInfo); - - __syncthreads(); - - - if (i >= length) - return; - - ixj = i^j; - - /* The threads with the lowest ids sort the array. */ - if ((ixj)>i) { - int posI = shape::getIndexOffset(i, yShapeInfo, xLength); - int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength); - - if ((i&k)==0) { - /* Sort ascending */ - if (!descending == (y[posI]>y[posIXJ])) { - /* exchange(i,ixj); */ - X temp = x[posI]; - x[posI] = x[posIXJ]; - x[posIXJ] = temp; - - Y ytemp = y[posI]; - y[posI] = y[posIXJ]; - y[posIXJ] = ytemp; - } - } else if ((i&k)!=0) { - /* Sort descending */ - if (!descending == (y[posI] @@ -189,13 +135,6 @@ __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, bitonicSortStepKernelKey<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); } -////////////////////////////////////////////////////////////////////////// -template -__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) { - bitonicSortStepKernelValue<<>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending); -} - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); -BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES); diff --git a/libnd4j/include/loops/cuda/specials/repeatKernel.cu b/libnd4j/include/loops/cuda/specials/repeatKernel.cu index 5193aca2a..c3177049f 100644 --- a/libnd4j/include/loops/cuda/specials/repeatKernel.cu +++ b/libnd4j/include/loops/cuda/specials/repeatKernel.cu @@ -62,9 +62,9 @@ namespace nd4j { } } } - BUILD_DOUBLE_TEMPLATE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer, + BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, - Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES, LIBND4J_TYPES); + Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES); template void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength, @@ -88,10 +88,10 @@ namespace nd4j { dim3 launchDims(256, 512, 8192); repeatKernelDouble<<>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets); } - BUILD_DOUBLE_TEMPLATE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, + BUILD_SINGLE_TEMPLATE_TWICE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets, Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets, - cudaStream_t stream), LIBND4J_TYPES, LIBND4J_TYPES); + cudaStream_t stream), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/specials/tileKernel.cu b/libnd4j/include/loops/cuda/specials/tileKernel.cu index d2c62ced7..7d2e87e2d 100644 --- a/libnd4j/include/loops/cuda/specials/tileKernel.cu +++ b/libnd4j/include/loops/cuda/specials/tileKernel.cu @@ -21,6 +21,17 @@ #include namespace nd4j { + static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) { + return shape::getIndexOffset(index, shapeInfo, length); + } + + static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) { + return shape::subArrayOffset(index, shapeInfoA, shapeInfoB); + } + + static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) { + return shape::length(shapeInfo); + } //////////////////////////////////////////////////////////////////////// template @@ -34,31 +45,20 @@ namespace nd4j { //const auto resultLength = shape::length(outputShape); if (shape::order(outputShape) == 'c') { // ews == 1 always here for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); *(reinterpret_cast(outputBuffer) + i) = *(reinterpret_cast(inputBuffer) + yOffset); } -// for(Nd4jLong i=0; itemplate templatedAssign, (newBuff, i, this->_buffer, yOffset), LIBND4J_TYPES); -// -// } } else { -// - //auto inputLength = shape::lenght(inputShape); for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = shape::getIndexOffset(i, outputShape, resultLength); - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + - yOffset); -// BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (newBuff, xOffset, this->_buffer, yOffset), LIBND4J_TYPES); + auto xOffset = _getIndexOffset(i, outputShape, resultLength); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = *(reinterpret_cast(inputBuffer) + yOffset); } } } - BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel, - (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), - LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template __global__ void tileKernel,(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength), LIBND4J_TYPES); template void tileKernelH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, cudaStream_t *stream) { @@ -77,29 +77,26 @@ namespace nd4j { if (ordering == 'c' && ews == 1) { // ews == 1 always here for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + - yOffset)); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } else if (ordering == 'c' && ews > 1) { for (int i = tid; i < resultLength; i += totalThreads) { - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*( - reinterpret_cast(inputBuffer) + yOffset)); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + i * ews) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } else { for (int i = tid; i < resultLength; i += totalThreads) { - auto xOffset = shape::getIndexOffset(i, outputShape, resultLength); - auto yOffset = shape::subArrayOffset(i, outputShape, inputShape); - *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*( - reinterpret_cast(inputBuffer) + yOffset)); + auto xOffset = _getIndexOffset(i, outputShape, resultLength); + auto yOffset = _subArrayOffset(i, outputShape, inputShape); + *(reinterpret_cast(outputBuffer) + xOffset) = static_cast(*(reinterpret_cast(inputBuffer) + yOffset)); } } } - BUILD_DOUBLE_TEMPLATE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void tileKernelDouble, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews), LIBND4J_TYPES); template void tileKernelHH(void const *inputBuffer, Nd4jLong *inputShape, void *outputBuffer, Nd4jLong *outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream) { @@ -107,5 +104,5 @@ namespace nd4j { tileKernelDouble<<>>(inputBuffer, inputShape, outputBuffer, outputShape, resultLength, ews); } - BUILD_DOUBLE_TEMPLATE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE_TWICE(template void tileKernelHH, (void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream),LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/summarystatsreduce.cu b/libnd4j/include/loops/cuda/summarystatsreduce.cu index cb3d06a4b..1e2f3ce4f 100644 --- a/libnd4j/include/loops/cuda/summarystatsreduce.cu +++ b/libnd4j/include/loops/cuda/summarystatsreduce.cu @@ -413,6 +413,74 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa DEBUG_KERNEL(stream, opNum); } + + template + Y SummaryStatsReduce::execScalar(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams) { + return 0; + } + + template + void SummaryStatsReduce::execScalar(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer) { + + } + + template + void SummaryStatsReduce::exec(int opNum, + bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer, + int *dimension, int dimensionLength) { + + } + + template + template + Y SummaryStatsReduce::execScalar(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams) { + return 0; + } + + template + template + void SummaryStatsReduce::execScalar(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer) { + // + } + + + template + template + void SummaryStatsReduce::exec(bool biasCorrected, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *vz, + Nd4jLong *resultShapeInfoBuffer, + int *dimension, + int dimensionLength) { + + } + + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_any.cu b/libnd4j/include/loops/cuda/transform/transform_any.cu index a217167a6..34f56380a 100644 --- a/libnd4j/include/loops/cuda/transform/transform_any.cu +++ b/libnd4j/include/loops/cuda/transform/transform_any.cu @@ -114,6 +114,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed"); } + template + void TransformAny::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { + + } + + template + template + void TransformAny::exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index bff361fcb..a01221cfa 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -120,6 +120,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed"); } + template + void TransformBool::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformBool::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index 05d4c9999..e1cd36256 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -142,6 +142,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed"); } + template + void TransformFloat::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformFloat::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES); } diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index 1e9bf2d64..a0d137d64 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -118,6 +118,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformSame(...) failed"); } + template + void TransformSame::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformSame::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES); } } diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 8a5b65c04..10385812d 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -119,6 +119,17 @@ namespace functions { nd4j::DebugHelper::checkErrorCode(stream, "transformStrict(...) failed"); } + template + void TransformStrict::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + + template + template + void TransformStrict::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + + } + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES); } } diff --git a/libnd4j/include/loops/cpu/type_conversions.cpp b/libnd4j/include/loops/impl/type_conversions.cpp similarity index 96% rename from libnd4j/include/loops/cpu/type_conversions.cpp rename to libnd4j/include/loops/impl/type_conversions.cpp index 3c923de39..dc85b9554 100644 --- a/libnd4j/include/loops/cpu/type_conversions.cpp +++ b/libnd4j/include/loops/impl/type_conversions.cpp @@ -209,15 +209,6 @@ PRAGMA_OMP_ATOMIC_ARGS(write) } }; - _CUDA_H Nd4jLong TypeCast::estimateQuantizedSize(Nd4jLong rawSize) { - if (rawSize <= 0) - throw std::runtime_error("Input size for quantization can't be <= 0"); - - // 2 fp32 values for max/min, and rawSize number of BYTES - return 8 + rawSize; - } - - template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); template void TypeCast::convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); diff --git a/libnd4j/include/loops/type_conversions.h b/libnd4j/include/loops/type_conversions.h index 1c54f41d4..d6029d7af 100644 --- a/libnd4j/include/loops/type_conversions.h +++ b/libnd4j/include/loops/type_conversions.h @@ -69,7 +69,14 @@ namespace nd4j { template static _CUDA_H void convertFromThreshold(Nd4jPointer * extras, void *dx, Nd4jLong N, void *dz); - static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize); + FORCEINLINE static _CUDA_H Nd4jLong estimateQuantizedSize(Nd4jLong rawSize) { + if (rawSize <= 0) + throw std::runtime_error("Input size for quantization can't be <= 0"); + + // 2 fp32 values for max/min, and rawSize number of BYTES + return 8 + rawSize; + } + template static _CUDA_H void convertToQuantized(Nd4jPointer *extras, void *dx, Nd4jLong N, void *dz); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index 7541ab841..1e0330294 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -75,7 +75,7 @@ namespace nd4j { DECLARE_TYPES(non_max_suppression) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); + ->setAllowedOutputTypes({ALL_INDICES}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp index fa9ab7b40..5484d822d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_hue.cpp @@ -87,8 +87,7 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { - - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), FLOAT_TYPES); } /* diff --git a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp index d01a8e2be..9a5141a82 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/adjust_saturation.cpp @@ -89,7 +89,7 @@ static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarA void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), FLOAT_TYPES); } /* diff --git a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp index b29a79504..b4a54ad7a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/col2im.cpp @@ -119,11 +119,9 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp void col2im(nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), col2im_, (context, input, output, sH, sW, pH, pW, iH, iW, dH, dW), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void col2im_, (nd4j::LaunchContext & context, const NDArray& input, NDArray& output, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW), LIBND4J_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index 6d319d993..033e0b5e5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -2445,71 +2445,52 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); } void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), FLOAT_TYPES); } void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), FLOAT_TYPES); } void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); } - - - BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - - BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp index c9a4e0fb5..c75bbf131 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/dilation2d.cpp @@ -81,10 +81,8 @@ static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const } } -BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES); - void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp index 49626168c..97cd2f84e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -76,7 +76,7 @@ namespace nd4j { double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp index 131165117..002c68226 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/im2col.cpp @@ -122,11 +122,9 @@ static void im2col_(nd4j::LaunchContext & context, const NDArray& input, NDArra void im2col(nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { - BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(im.dataType(), im2col_, (context, im, col, kH, kW, sH, sW, pH, pW, dH, dW, arrZeroPadVal), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void im2col_, (nd4j::LaunchContext & context, const NDArray& im, NDArray& col, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal), LIBND4J_TYPES); - } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 062db8d87..2ac679fc5 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -334,10 +334,6 @@ namespace helpers { BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); } - - BUILD_TRIPLE_TEMPLATE(template void cropAndResizeFunctor_, - (NDArray const* images, NDArray const* boxes, NDArray const* indices, NDArray const* cropSize, int method, double extrapolationVal, NDArray* crops), - NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 52a41a201..45024b5cb 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -32,7 +32,6 @@ namespace helpers { theFirst->applyPairwiseLambda(theSecond, functor, nullptr); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); @@ -46,7 +45,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); @@ -61,8 +59,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -76,8 +72,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -91,8 +85,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -106,8 +98,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -121,8 +111,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void cubeDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -137,8 +125,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reduceNorm1_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -153,8 +139,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropy_, (NDArray* logits, NDArray* labels, NDArray* output);, FLOAT_TYPES); - void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); } @@ -173,8 +157,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropyGrad_, (NDArray* logits, NDArray* labels, NDArray*output);, FLOAT_TYPES); - void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); } @@ -190,8 +172,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -207,8 +187,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -222,8 +200,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -237,8 +213,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -256,8 +230,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softSignDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -272,8 +244,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softPlusDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -291,8 +261,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -306,8 +274,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardSigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -347,13 +313,10 @@ namespace helpers { void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* axis, NDArray*output);, FLOAT_TYPES); void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* subtrah, NDArray* axis, NDArray*output);, FLOAT_TYPES); - ////////////////////////////////////////////////////////////////////////// template @@ -393,7 +356,6 @@ static void weightedCrossEntropyWithLogitsFunctor_(NDArray const* targets, NDArr void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output) { BUILD_SINGLE_SELECTOR(targets->dataType(), weightedCrossEntropyWithLogitsFunctor_, (targets, input, weights, output), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void weightedCrossEntropyWithLogitsFunctor_, (NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp index 75b23c932..a02d5918c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lrn.cpp @@ -410,10 +410,9 @@ static void lrnBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI, c gradI *= gradO; } -BUILD_DOUBLE_TEMPLATE(template void lrnBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta), LIBND4J_TYPES, FLOAT_TYPES); void lrnBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) { - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 29d9f463b..1fb1ef1df 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -345,8 +345,6 @@ template int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); - BUILD_SINGLE_TEMPLATE(template int inverse_, (NDArray* input, NDArray* output), FLOAT_TYPES); template int logdetFunctor_(NDArray* input, NDArray* output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp b/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp deleted file mode 100644 index 6990f2dc3..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/matmul.cpp +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.12.17. -// - -#include - -namespace nd4j { - namespace ops { - namespace helpers { - template - void __matmul(NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - CBLAS_TRANSPOSE tA = (CBLAS_TRANSPOSE) transA; - CBLAS_TRANSPOSE tB = (CBLAS_TRANSPOSE) transB; - - int M = vA->sizeAt(0); - int N = vB->sizeAt(1); - int K = vA->sizeAt(1); - - int ldA = transA == CblasNoTrans ? M : K; - int ldB = transB == CblasNoTrans ? K : N; - int ldC = M; - - auto A = reinterpret_cast(vA->buffer()); - auto B = reinterpret_cast(vB->buffer()); - auto C = reinterpret_cast(vC->buffer()); - - PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) - for (int m = 0; m < M; ++m) { - for (int n = 0; n < N; ++n) { - Z c_mnp = 0; - - for (int k = 0; k < K; ++k) - c_mnp += (Z) A[tA == CblasNoTrans ? (m + k * ldA) : (m * ldA + k)] * (Z) B[tB == CblasNoTrans ? (k + n * ldB) : (k * ldB + n)]; - - C[m + n * ldC] = (Z) alpha * (Z) c_mnp + (Z) beta * (Z) C[m + n * ldC]; - } - } - } - - - void _matmul(nd4j::LaunchContext * context, NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - BUILD_TRIPLE_SELECTOR(vA->dataType(), vB->dataType(), vC->dataType(), __matmul, (vA, vB, vC, transA, transB, alpha, beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - - BUILD_TRIPLE_TEMPLATE(template void __matmul, (NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha, double beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - } -} diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index 06fe2eec2..6ebca9184 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -76,9 +76,6 @@ namespace helpers { BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); } - - BUILD_SINGLE_TEMPLATE(template void maxPoolingFunctor_, (nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices), FLOAT_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp index c38df008f..0fc6eea0b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/toggle_bits.cpp @@ -32,7 +32,6 @@ namespace nd4j { in.applyLambda(lambda, &out); } - BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); void __toggle_bits(nd4j::LaunchContext * context, NDArray& in, NDArray& out) { BUILD_SINGLE_SELECTOR(in.dataType(), toggle_bits__, (in, out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index 71641f215..3536f9f62 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -56,9 +56,6 @@ static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const N BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES); } - -BUILD_SINGLE_TEMPLATE(template void triuBP_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void trace_(const NDArray& input, NDArray& output) { @@ -78,8 +75,6 @@ static void trace_(const NDArray& input, NDArray& output) { BUILD_SINGLE_SELECTOR(input.dataType(), trace_, (input, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void trace_, (const NDArray& input, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { @@ -173,14 +168,6 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerato BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); - - - - - - - ////////////////////////////////////////////////////////////////////////// template @@ -387,8 +374,6 @@ void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, co BUILD_SINGLE_SELECTOR(input.dataType(), pad_, (mode, input, paddings, output, padValue), LIBND4J_TYPES); } -BUILD_SINGLE_TEMPLATE(template void pad_, (const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, NDArray const& padValue), LIBND4J_TYPES); - //////////////////////////////////////////////////////////////////////// /*// initial values of inIdx, outIdx, dim must be equal to zero template @@ -623,9 +608,8 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { //////////////////////////////////////////////////////////////////////// void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { - BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES); } -BUILD_DOUBLE_TEMPLATE(template void gatherND_, (NDArray& input, NDArray& indices, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES); //////////////////////////////////////////////////////////////////////// @@ -705,8 +689,6 @@ static void gather_(NDArray* input, const NDArray* indices, NDArray* output, con BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void gather_, (NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// void eye(nd4j::LaunchContext * context, NDArray& output) { @@ -826,7 +808,6 @@ static void mergeMaxIndex_(const std::vector& inArrs, NDArray& output) BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeMaxIndex_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -850,8 +831,6 @@ static void mergeMax_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeMax_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void mergeAvg_(const std::vector& inArrs, NDArray& output) { @@ -874,7 +853,6 @@ static void mergeAvg_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -898,8 +876,6 @@ static void mergeAdd_(const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (inArrs, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// template static void clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { @@ -970,11 +946,6 @@ void clipByNorm(nd4j::LaunchContext * context, NDArray& input, NDArray& output, BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); } -BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - - - - diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index 33805e335..1397874f8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -99,7 +99,7 @@ void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& a const auto yType = alpha.dataType(); NDArray::prepareSpecialUse({&output}, {&input, &alpha}); - BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(xType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input, &alpha}); manager.synchronize(); @@ -189,7 +189,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& const auto zType = alpha.dataType(); NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); - BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(xType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), FLOAT_TYPES); NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO}); manager.synchronize(); @@ -574,14 +574,6 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr BUILD_SINGLE_SELECTOR(input->dataType(), thresholdReluDerivative_, (input, threshold, dLdO, output), FLOAT_TYPES); } - -BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); -BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); - - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu index def7d316f..e8062e126 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_hue.cu @@ -78,7 +78,6 @@ static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int thr adjustHueCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e(0), dimC); } -BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { @@ -94,7 +93,7 @@ void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray PointersManager manager(context, "adjustHue"); NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu index ce910a892..4ab8da304 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/adjust_saturation.cu @@ -80,7 +80,6 @@ static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const adjustSaturationCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e(0), dimC); } -BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { @@ -96,7 +95,7 @@ void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const PointersManager manager(context, "adjustSaturation"); NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); - BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, factorScalarArr}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu index 8a5dbd744..7678779ac 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/batchnorm.cu @@ -182,7 +182,6 @@ __host__ static void batchnormCudaLauncher(const int blocksPerGrid, const int th batchnormCuda<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, static_cast(epsilon)); } -BUILD_SINGLE_TEMPLATE(template void batchnormCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const double epsilon), FLOAT_TYPES); /////////////////////////////////////////////////////////////////// template @@ -198,7 +197,6 @@ __host__ static void batchnormCudaLauncher2(const int blocksPerGrid, const int t batchnormCuda2<<>>(vx, xShapeInfo, vMean, meanShapeInfo, vVariance, varianceShapeInfo, vGamma, gammaShapeInfo, vBeta, betaShapeInfo, vz, zShapeInfo, numDims, dims, static_cast(epsilon)); } -BUILD_SINGLE_TEMPLATE(template void batchnormCudaLauncher2, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vMean, const Nd4jLong* meanShapeInfo, const void* vVariance, const Nd4jLong* varianceShapeInfo, const void* vGamma, const Nd4jLong* gammaShapeInfo, const void* vBeta, const Nd4jLong* betaShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int numDims, const int* dims, const double epsilon), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void batchnorm(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu b/libnd4j/include/ops/declarable/helpers/cuda/bds.cu index 6aef74adb..ef501eac0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/bds.cu @@ -107,7 +107,6 @@ namespace helpers { return Status::OK(); return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template void bdsLoopH, (cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu index 2088e18fe..e02bce146 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/col2im.cu @@ -189,7 +189,6 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc // col2imCuda2<<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW); col2imCuda<<>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void col2imCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *im, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) { @@ -201,7 +200,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&im}, {&col}); - BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); NDArray::registerSpecialUse({&im}, {&col}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index e993b370e..44a0156d7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -98,7 +98,6 @@ static void vol2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlo vol2colCuda<<>>(volume, volShapeInfo, columns, colShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void vol2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *vol, const Nd4jLong *volShapeInfo, void *col, const Nd4jLong *colShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { @@ -205,7 +204,6 @@ static void col2volCudaLauncher(const int blocksPerGrid, const int threadsPerBlo col2volCuda<<>>(columns, colShapeInfo, volume, volShapeInfo, sD, sH, sW, pD, pH, pW, dD, dH, dW); } -BUILD_SINGLE_TEMPLATE(template void col2volCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *vol, const Nd4jLong *volShapeInfo, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { @@ -285,7 +283,7 @@ static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDA ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -345,7 +343,7 @@ static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -390,7 +388,7 @@ static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const ND ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -488,7 +486,6 @@ template static void avgPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { avgPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void avgPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -582,7 +579,6 @@ template static void pnormPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { pnormPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void pnormPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -679,7 +675,6 @@ template static void maxPooling2dCudaLauncher(nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0) { maxPooling2dCuda<<<512, 512, 4192, *block.getCudaStream()>>>(vx, vxShapeInfo, vz, vzShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, extraParam0); } -BUILD_DOUBLE_TEMPLATE(template void maxPooling2dCudaLauncher, (nd4j::LaunchContext & block, void *vx, Nd4jLong *vxShapeInfo, void *vz, Nd4jLong *vzShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int extraParam0), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { @@ -689,15 +684,15 @@ void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& inp switch (poolingMode) { case MAX_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), maxPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; case AVG_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), avgPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; case PNORM_POOL: { - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input.dataType(), pnormPooling2dCudaLauncher, (*block.launchContext(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, extraParam0), FLOAT_TYPES); } break; default: @@ -845,7 +840,6 @@ static void pooling3dCudaLauncher(const int blocksPerGrid, const int threadsPerB pooling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling3dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -857,49 +851,12 @@ void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& inp const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); } - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ////////////////////////////////////////////////////////////////////////// template __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1032,7 +989,6 @@ static void pooling2dBPCudaLauncher(const int blocksPerGrid, const int threadsPe pooling2dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling2dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1047,7 +1003,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); manager.synchronize(); @@ -1201,7 +1157,6 @@ static void pooling3dBPCudaLauncher(const int blocksPerGrid, const int threadsPe pooling3dBPCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0); } -BUILD_SINGLE_TEMPLATE(template void pooling3dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { @@ -1216,7 +1171,7 @@ void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& i const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); manager.synchronize(); @@ -1292,11 +1247,10 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N delete gradI; } } -BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -1374,11 +1328,10 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con delete gradI; } } -BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); } @@ -1434,7 +1387,6 @@ static void upsampling2dCudaLauncher(const int blocksPerGrid, const int threadsP upsampling2dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorH, factorW, isNCHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { @@ -1446,7 +1398,7 @@ void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorH, factorW, isNCHW), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); @@ -1505,7 +1457,6 @@ static void upsampling3dCudaLauncher(const int blocksPerGrid, const int threadsP upsampling3dCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, factorD, factorH, factorW, isNCDHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling3dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { @@ -1517,7 +1468,7 @@ void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&output}, {&input}); - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), factorD, factorH, factorW, isNCDHW), FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&input}); manager.synchronize(); @@ -1579,7 +1530,6 @@ static void upsampling2dBPCudaLauncher(const int blocksPerGrid, const int thread upsampling2dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling2dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { @@ -1591,7 +1541,7 @@ void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling2dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCHW), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&gradO}); manager.synchronize(); @@ -1656,7 +1606,6 @@ static void upsampling3dBPCudaLauncher(const int blocksPerGrid, const int thread upsampling3dBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, isNCDHW); } -BUILD_SINGLE_TEMPLATE(template void upsampling3dBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const bool isNCDHW), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { @@ -1668,7 +1617,7 @@ void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray const int sharedMem = gradI.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({&gradI}, {&gradO}); - BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(gradI.dataType(), upsampling3dBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), isNCDHW), FLOAT_TYPES); NDArray::registerSpecialUse({&gradI}, {&gradO}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index 423944e0f..f4dff2279 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -100,19 +100,12 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha input->syncToDevice(); diagPartFunctorKernel<<>>(output->specialBuffer(), output->specialShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), outLen, inLen); -// int i(0), j; -// for (j = 0;j < outLen; j++) { -// output->p(j, input->e(i)); -// i += outLen + 1; -// } - } - BUILD_SINGLE_TEMPLATE(template void _diagPartFunctor, (nd4j::LaunchContext * context, const NDArray* input, NDArray* output);, LIBND4J_TYPES); void diagPartFunctor(nd4j::LaunchContext * context, NDArray const* input, NDArray* output) { auto zType = output->dataType(); - BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(zType, _diagPartFunctor, (context, input, output), NUMERIC_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu index e23c4c84f..a636af891 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dilation2d.cu @@ -114,8 +114,6 @@ static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPer dilation2dCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW); } -BUILD_DOUBLE_TEMPLATE(template void dilation2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES); - void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { PointersManager manager(context, "dilation2d"); @@ -125,7 +123,7 @@ void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(Nd4jLong) * threadsPerBlock + 128; NDArray::prepareSpecialUse({output}, {input, weights}); - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES); NDArray::registerSpecialUse({output}, {input, weights}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 952bf47c7..a01b4f555 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -73,8 +73,6 @@ namespace helpers { NDArray::registerSpecialUse({output}, {input}); } - BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); - template int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { @@ -124,8 +122,6 @@ namespace helpers { BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); - /////////////////////////////////// backrpopagations /////////////////////////////////////////////// template static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) { @@ -260,17 +256,14 @@ namespace helpers { int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index d6a2d26bb..857ebed38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -306,7 +306,7 @@ namespace nd4j { NDArray::prepareSpecialUse({}, {indices, input}); - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicPartitionFunctor, (context, input, indices, outputList), NUMERIC_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({}, {indices, input}); @@ -336,7 +336,7 @@ namespace nd4j { NDArray::prepareSpecialUse({output}, {}); - BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, _dynamicStitchFunctor, (context, inputs, indices, output), NUMERIC_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {}); @@ -346,22 +346,15 @@ namespace nd4j { int dynamicStitchFunctorBP(nd4j::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList) { auto xType = inputs.at(0)->dataType(); - BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, return _dynamicStitchFunctorBP, (inputs, indices, gradInput, outputList), NUMERIC_TYPES); } void dynamicPartitionFunctorBP(nd4j::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList) { auto xType = input->dataType(); - BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, _dynamicPartitionFunctorBP, (input, indices, inputGradientList, outputList), NUMERIC_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _dynamicPartitionFunctorBP, (NDArray const* input, NDArray const* indices, std::vector const& inputGradientList, std::vector& outputList);, LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template int _dynamicStitchFunctorBP, (std::vector const& inputs, std::vector const& indices, NDArray const* gradInput, std::vector& outputList);, LIBND4J_TYPES); - - BUILD_DOUBLE_TEMPLATE(template void _dynamicPartitionFunctor, (nd4j::LaunchContext * context, NDArray const* input, NDArray const* indices, std::vector& outputList);, LIBND4J_TYPES, INTEGER_TYPES); - BUILD_DOUBLE_TEMPLATE(template int _dynamicStitchFunctor, (nd4j::LaunchContext * context, std::vector const& inputs, std::vector const& indices, NDArray* output);, LIBND4J_TYPES, INTEGER_TYPES); - - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu index 5415ddab1..aabd9e949 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu @@ -164,13 +164,13 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* sizeof(Nd4jLong))); NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->getSpecialBuffer(), xShapeInfo, xOffsets, pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->getSpecialBuffer(), zShapeInfo, zOffsets), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLauncher, (context->getCudaStream(), numOfSubArrs, input->getSpecialBuffer(), xShapeInfo, xOffsets, pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->getSpecialBuffer(), zShapeInfo, zOffsets), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {input, pIndices}); manager.synchronize(); } else { NDArray::prepareSpecialUse({output}, {input, pIndices}); - BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), pIndices->dataType(), gatherCudaLinear, (context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), pIndices->getSpecialBuffer(), pIndices->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {input, pIndices}); } @@ -181,12 +181,6 @@ void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* } } - -BUILD_DOUBLE_TEMPLATE(template void gatherCudaLauncher, (const cudaStream_t *stream, const int numOfSubArrs, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xOffsets, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zOffsets), NUMERIC_TYPES, INTEGER_TYPES); -BUILD_DOUBLE_TEMPLATE(template void gatherCudaLinear, (const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo), NUMERIC_TYPES, INTEGER_TYPES); - - - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index 614ac95c1..71dc284a6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -120,7 +120,6 @@ namespace nd4j { gatherNDCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } - BUILD_DOUBLE_TEMPLATE(template void gatherNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), LIBND4J_TYPES, INTEGER_TYPES); /////////////////////////////////////////////////////////////////// void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) { @@ -137,7 +136,7 @@ namespace nd4j { PointersManager manager(context, "gatherND"); NDArray::prepareSpecialUse({&output}, {&input, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input, &indices}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index e04b1b57a..eda19ccd8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -125,7 +125,7 @@ namespace nd4j { double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu index 73cae9d80..3e8ec6836 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/im2col.cu @@ -85,7 +85,6 @@ template static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) { im2colCuda<<>>(image, columns, imShapeInfo, colShapeInfo, sH, sW, pH, pW, dH, dW, zeroPadVal); } -BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext& context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) { @@ -96,7 +95,7 @@ void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns const int blocksPerGrid = (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; NDArray::prepareSpecialUse({&columns}, {&image}); - BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e(0)), FLOAT_TYPES); NDArray::registerSpecialUse({&columns}, {&image}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index cd6887bf0..2cec0a065 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -85,8 +85,8 @@ namespace helpers { *shouldSelect = shouldSelectShared; } } - template + template static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) { __shared__ I* indexBuf; __shared__ Nd4jLong* srcBuf; @@ -115,15 +115,15 @@ namespace helpers { sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); // TO DO: sort indices using scales as value row //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e(i) > scales->e(j);}); - I* indexBuf = reinterpret_cast(indices->specialBuffer()); + auto indexBuf = reinterpret_cast(indices->specialBuffer()); NDArray selectedIndices = NDArrayFactory::create('c', {output->lengthOf()}); int numSelected = 0; int numBoxes = boxes->sizeAt(0); - T* boxesBuf = reinterpret_cast(boxes->specialBuffer()); + auto boxesBuf = reinterpret_cast(boxes->specialBuffer()); - I* selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); - I* outputBuf = reinterpret_cast(output->specialBuffer()); + auto selectedIndicesData = reinterpret_cast(selectedIndices.specialBuffer()); + auto outputBuf = reinterpret_cast(output->specialBuffer()); bool* shouldSelectD; auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); @@ -138,8 +138,7 @@ namespace helpers { throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err); } - shouldSelectKernel <<< 128, 256, 1024, *stream >>> - (boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); + shouldSelectKernel<<<128, 256, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost); if (err) { throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err); @@ -161,9 +160,8 @@ namespace helpers { } void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { - BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(boxes->dataType(), output->dataType(), nonMaxSuppressionV2_, (context, boxes, scales, maxSize, threshold, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void nonMaxSuppressionV2_, (nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index 46f972f44..a0f30a116 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -34,7 +34,6 @@ namespace nd4j { theFirst->applyPairwiseLambda(theSecond, functor, nullptr); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative__, (NDArray* input, NDArray* epsilon), FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative__, (theFirst, theSecond), FLOAT_TYPES); @@ -48,7 +47,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); @@ -63,8 +61,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void relu6Derivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), relu6Derivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -78,8 +74,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void leakyReluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -93,8 +87,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void eluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -108,8 +100,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void seluDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), seluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu index 9ad1ee0ad..017180b38 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/tanh.cu @@ -36,8 +36,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void tanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void tanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), tanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -53,8 +51,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -68,8 +64,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rationalTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rationalTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rationalTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -83,8 +77,6 @@ namespace nd4j { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void rectifiedTanhDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void rectifiedTanhDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), rectifiedTanhDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu index 6d0788c64..defdfaf09 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy_helper.cu @@ -35,8 +35,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void cubeDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), cubeDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -51,8 +49,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void reduceNorm1_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), reduceNorm1_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -67,8 +63,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropy_, (NDArray* logits, NDArray* labels, NDArray* output);, FLOAT_TYPES); - void sigmCrossEntropy(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropy_, (logits, labels, output), FLOAT_TYPES); } @@ -87,8 +81,6 @@ namespace helpers { logits->applyPairwiseLambda(labels, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmCrossEntropyGrad_, (NDArray* logits, NDArray* labels, NDArray*output);, FLOAT_TYPES); - void sigmCrossEntropyGrad(nd4j::LaunchContext * context, NDArray* logits, NDArray* labels, NDArray* output) { BUILD_SINGLE_SELECTOR(logits->dataType(), sigmCrossEntropyGrad_, (logits, labels, output), FLOAT_TYPES); } @@ -106,8 +98,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softSignDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softSignDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softSignDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -122,8 +112,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void softPlusDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void softPlusDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), softPlusDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -141,8 +129,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void sigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void sigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), sigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -156,8 +142,6 @@ namespace helpers { input->applyPairwiseLambda(epsilon, functor, output); } - BUILD_SINGLE_TEMPLATE(template void hardSigmoidDerivative_, (NDArray* input, NDArray* epsilon, NDArray*output);, FLOAT_TYPES); - void hardSigmoidDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { BUILD_SINGLE_SELECTOR(theFirst->dataType(), hardSigmoidDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); } @@ -197,12 +181,10 @@ namespace helpers { void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* axis, NDArray*output);, FLOAT_TYPES); void logSumExp(nd4j::LaunchContext * context, NDArray* input, NDArray* subtrah, NDArray* axis, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), logSumExp_, (input, subtrah, axis, output), FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void logSumExp_, (NDArray* input, NDArray* subtrah, NDArray* axis, NDArray*output);, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -246,7 +228,7 @@ void weightedCrossEntropyWithLogitsFunctor(nd4j::LaunchContext * context, NDArra NDArray::registerSpecialUse({output}, {targets, input, weights}); } -BUILD_SINGLE_TEMPLATE(template void weightedCrossEntropyWithLogitsFunctor_, (NDArray const* targets, NDArray const* input, NDArray const* weights, NDArray* output), FLOAT_TYPES); + } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index baabf6574..f27511b3a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -148,7 +148,7 @@ namespace helpers { input.syncToDevice(); gradO.syncToDevice(); - BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES); gradI.tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 9c8dff3a5..ffd652ee7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -212,8 +212,6 @@ namespace helpers { invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); - void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); } @@ -232,8 +230,6 @@ namespace helpers { invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); - void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); } @@ -562,8 +558,6 @@ namespace helpers { return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); - int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); } @@ -612,8 +606,6 @@ namespace helpers { return ND4J_STATUS_OK; } - BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); - int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu b/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu deleted file mode 100644 index 322966836..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/matmul.cu +++ /dev/null @@ -1,39 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// Created by raver119 on 20.12.17. -// - -#include - -namespace nd4j { - namespace ops { - namespace helpers { - template - void __matmul(NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - - } - - - void _matmul(nd4j::LaunchContext * context, NDArray *vA, NDArray *vB, NDArray *vC, int transA, int transB, double alpha, double beta) { - BUILD_TRIPLE_SELECTOR(vA->dataType(), vB->dataType(), vC->dataType(), __matmul, (vA, vB, vC, transA, transB, alpha, beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - - BUILD_TRIPLE_TEMPLATE(template void __matmul, (NDArray *A, NDArray *B, NDArray *C, int transA, int transB, double alpha, double beta), LIBND4J_TYPES, LIBND4J_TYPES, LIBND4J_TYPES); - } - } -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index d3aa58a9c..d5af6328a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -88,13 +88,10 @@ namespace helpers { void maxPoolingFunctor(nd4j::LaunchContext * context, nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { NDArray::prepareSpecialUse({values, indices}, {input}); auto yType = indices == nullptr ? nd4j::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({values, indices}, {input}); } - - BUILD_DOUBLE_TEMPLATE(template void maxPoolingFunctor_, (nd4j::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices), FLOAT_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu index 0af1f0eda..a2aec252e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/maximum.cu @@ -107,7 +107,6 @@ namespace nd4j { NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } - BUILD_SINGLE_TEMPLATE(template void maximumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 3c8d159be..ceb748453 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -79,10 +79,9 @@ namespace nd4j { } void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void mergeMaxIndex_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES); ////////////////////////////////////////////////////////////////////////// template @@ -128,7 +127,6 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeMax_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); void mergeMax(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); @@ -176,10 +174,9 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeAvg_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); void mergeAvg(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); } ////////////////////////////////////////////////////////////////////////// @@ -224,10 +221,10 @@ namespace nd4j { manager.synchronize(); } - BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); void mergeAdd(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu index 2647a53df..ea4a1e146 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/meshgrid.cu @@ -136,7 +136,7 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// void meshgrid(nd4j::LaunchContext * context, const std::vector& inArrs, const std::vector& outArrs, const bool swapFirst2Dims) { - BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(inArrs.at(0)->dataType(), meshgrid_, (context, inArrs, outArrs, swapFirst2Dims), NUMERIC_TYPES); for (auto v:outArrs) v->tickWriteDevice(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu index 12f888005..75c73f96b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/minimum.cu @@ -109,8 +109,6 @@ namespace nd4j { NDArray::registerSpecialUse({gradX, gradY}, {x, y, epsNext}); } - BUILD_SINGLE_TEMPLATE(template void minimumBPFunctor_, (NDArray* x, NDArray* y, NDArray* epsNext, NDArray* gradX, NDArray* gradY), NUMERIC_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index 80662a19b..aeddd3b97 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -88,8 +88,7 @@ namespace helpers { BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (context, input, n, output, reverse), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (nd4j::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); - + } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu index b268e6366..ef74180c8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/pad.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/pad.cu @@ -128,7 +128,6 @@ namespace nd4j { padCuda<<>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal); } - BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES); /////////////////////////////////////////////////////////////////// void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) { @@ -144,7 +143,7 @@ namespace nd4j { const auto xType = input.dataType(); const auto yType = paddings.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, padCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), mode, input.getSpecialBuffer(), input.getSpecialShapeInfo(), paddings.getSpecialBuffer(), paddings.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), padValue.getSpecialBuffer()), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&output}, {&input, &paddings, &padValue}); manager.synchronize(); @@ -272,11 +271,9 @@ namespace nd4j { } void mirrorPad(nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode) { - BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input.dataType(), paddings.dataType(), mirrorPad_, (context, input, paddings, output, mode), LIBND4J_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void mirrorPad_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode), LIBND4J_TYPES, INTEGER_TYPES); - } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu index 90b9e5d5f..53cfcc22d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/prefix.cu @@ -160,7 +160,7 @@ void prefix(nd4j::LaunchContext * context, scalar::Ops op, const NDArray* x, NDA PointersManager manager(context, "prefix"); NDArray::prepareSpecialUse({z}, {x}); - BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(x->dataType(), prefixPerBlockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, x->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), z->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLen, exclusive, reverse), NUMERIC_TYPES); NDArray::registerSpecialUse({z}, {x}); manager.synchronize(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index 323877e47..7e8ddb2a7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -46,7 +46,7 @@ namespace helpers { BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index b6f0c215a..776d92c45 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -415,7 +415,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } else { @@ -426,7 +426,7 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } NDArray::registerSpecialUse({&output}, {&updates, &indices}); @@ -714,7 +714,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDLockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), output.getSpecialShapeInfo(), packX.numberOfTads(), packZ.numberOfTads(), shape::length(packY.primaryShapeInfo())), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDLockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), output.getSpecialShapeInfo(), packX.numberOfTads(), packZ.numberOfTads(), shape::length(packY.primaryShapeInfo())), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } else { @@ -725,7 +725,7 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i const auto xType = indices.dataType(); const auto yType = updates.dataType(); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); } NDArray::registerSpecialUse({&output}, {&updates, &indices}); @@ -797,26 +797,18 @@ void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArra if(calcGrad) { NDArray::prepareSpecialUse({&updates}, {&indices}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INTEGER_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); NDArray::registerSpecialUse({&updates}, {&indices}); } else { NDArray::prepareSpecialUse({&output}, {&indices, &updates}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INTEGER_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); NDArray::registerSpecialUse({&output}, {&indices, &updates}); } manager.synchronize(); } - - - -BUILD_DOUBLE_TEMPLATE(template void scatterCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterLockCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); -BUILD_DOUBLE_TEMPLATE(template void scatterNDLockCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int opCode, const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong *zShapeInfo, const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong zTadLen), INTEGER_TYPES, GENERIC_NUMERIC_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu index 5d3c4eb52..f1eda6b01 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter_simple.cu @@ -70,7 +70,7 @@ namespace nd4j { NDArray::prepareSpecialUse({&input}, {&updates, &indices}); - BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterSimple_, (context, opId, input, updates, indices, dimensions), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({&input}, {&updates, &indices}); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu index 67cb77b5c..4aa5c762d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment.cu @@ -40,12 +40,9 @@ namespace helpers { } bool segmentIndicesValidate(nd4j::LaunchContext* context , NDArray* indices, NDArray& expected, NDArray& output) { - BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output.dataType(), indices->dataType(), return segmentIndicesValidate_, (indices, expected, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), NUMERIC_TYPES, INTEGER_TYPES); - - // -------------------------------------------------------------------------------------------------------------- // // Unsorted segment ops functors implementation // -------------------------------------------------------------------------------------------------------------- // @@ -85,9 +82,9 @@ namespace helpers { } bool unsortedSegmentIndicesValidate(nd4j::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(indices->dataType(), return unsortedSegmentIndicesValidate_, (context, indices, expected, output), INDEXING_TYPES); } - BUILD_SINGLE_TEMPLATE(template bool unsortedSegmentIndicesValidate_, (nd4j::LaunchContext* context , NDArray* indices, Nd4jLong expected, Nd4jLong& output), INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- // @@ -126,9 +123,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens) { - BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INTEGER_TYPES); + BUILD_SINGLE_SELECTOR(indices->dataType(), fillUpSegments_, (indices, numClasses, classesRangesBegs, classesRangesLens), INDEXING_TYPES); } - BUILD_SINGLE_TEMPLATE(template void fillUpSegments_, (NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens), INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index a1792750f..20796b1d1 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -201,9 +201,8 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -241,10 +240,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // // segment max // -------------------------------------------------------------------------------------------------------------- // @@ -371,10 +369,8 @@ namespace nd4j { // -------------------------------------------------------------------------------------------------------------- // int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -418,10 +414,8 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index 19c50728a..c60272188 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -186,9 +186,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // template static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { @@ -226,10 +226,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -351,11 +349,9 @@ namespace helpers { // segmen mean bp main int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -402,12 +398,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index b5c76e18d..de602201b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -192,9 +192,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // @@ -235,11 +234,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); template static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, @@ -366,10 +363,8 @@ namespace helpers { // segmen min int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -412,12 +407,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 0a7c73040..7454756b5 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -192,9 +192,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -233,10 +232,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -360,11 +357,9 @@ namespace helpers { int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, - indices, gradOut, output), FLOAT_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template @@ -407,10 +402,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 6e3ab24d9..875f63e77 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -147,9 +147,8 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), - FLOAT_TYPES, INTEGER_TYPES); + FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // template static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, @@ -270,11 +269,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 4f2cc92a1..1d9d983ef 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -190,9 +190,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); + // -------------------------------------------------------------------------------------------------------------- // template static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { @@ -230,11 +230,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), - NUMERIC_TYPES, INTEGER_TYPES); + NUMERIC_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); // -------------------------------------------------------------------------------------------------------------- // // Backpropagate ops @@ -344,10 +342,8 @@ namespace helpers { int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, - indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); + indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); } - BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); - // -------------------------------------------------------------------------------------------------------------- // template static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { @@ -383,10 +379,8 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); } - // -------------------------------------------------------------------------------------------------------------- // - BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu index fd2f3db6c..150c616a6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sru.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sru.cu @@ -231,7 +231,6 @@ static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock sruBICuda<<>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void sruBICudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu index 36b369113..db6213dd3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu @@ -101,7 +101,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con const auto yType = targets->dataType(); NDArray::prepareSpecialUse({output}, {predictions, targets}); - BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(), targets->getSpecialBuffer(), targets->getSpecialShapeInfo(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(), targets->getSpecialBuffer(), targets->getSpecialShapeInfo(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({output}, {predictions, targets}); manager.synchronize(); @@ -269,7 +269,7 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con int topKFunctor(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { input->syncToDevice(); - BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INTEGER_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INDEXING_TYPES); values->tickWriteDevice(); indices->tickWriteDevice(); @@ -277,9 +277,6 @@ int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, con return Status::OK(); } - - BUILD_DOUBLE_TEMPLATE(template int topKFunctor_, (nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort), LIBND4J_TYPES, INTEGER_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 19c726581..bb311ed01 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -68,7 +68,6 @@ __host__ static void invertPermutationCudaLauncher(const int blocksPerGrid, cons invertPermutationCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); } -BUILD_SINGLE_TEMPLATE(template void invertPermutationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// void invertPermutation(nd4j::LaunchContext* context, const NDArray& input, NDArray& output) { @@ -149,7 +148,7 @@ static void traceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock traceCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diagLen); } -BUILD_SINGLE_TEMPLATE(template void traceCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint diagLen), LIBND4J_TYPES); + /////////////////////////////////////////////////////////////////// void trace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output) { @@ -214,7 +213,6 @@ static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBloc triuBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, diag); } -BUILD_SINGLE_TEMPLATE(template void triuBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag), LIBND4J_TYPES); /////////////////////////////////////////////////////////////////// void triuBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { @@ -280,7 +278,6 @@ static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBloc tileBPCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, globMem); } -BUILD_SINGLE_TEMPLATE(template void tileBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem), FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// @@ -526,7 +523,7 @@ static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsP else // means tads using clipByNormBPTadsCuda<<>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast(clipNormVal)); } -BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); +BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES); ////////////////////////////////////////////////////////////////////////// void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { @@ -547,7 +544,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr if(dimensions.empty() || dimensions.size() == input.rankOf()) { // means whole array const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), nullptr, gradI.getSpecialBuffer(), gradI.getSpecialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), nullptr, gradI.getSpecialBuffer(), gradI.getSpecialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES); } else { // means tads using @@ -556,7 +553,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.getShapeInfo(), dimensions); const int blocksPerGrid = packX.numberOfTads(); - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.getSpecialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.getSpecialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.getSpecialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.getSpecialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES); } NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); diff --git a/libnd4j/include/play.h b/libnd4j/include/play.h index 1d4ad80dc..ecafe84ea 100644 --- a/libnd4j/include/play.h +++ b/libnd4j/include/play.h @@ -40,7 +40,7 @@ (float, long, long) -BUILD_SINGLE_SELECTOR_THRICE(xType, template class functionName, , DATA_TYPES); +BUILD_SINGLE_TEMPLATE_TWICE(template class functionName, , DATA_TYPES) //BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functionName, (signature), DATA_TYPES, Y_TYPES); diff --git a/libnd4j/include/type_boilerplate.h b/libnd4j/include/type_boilerplate.h index 69ad370b0..bd235726a 100644 --- a/libnd4j/include/type_boilerplate.h +++ b/libnd4j/include/type_boilerplate.h @@ -546,10 +546,13 @@ #ifndef __CLION_IDE__ #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} + #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } @@ -559,8 +562,10 @@ #else #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +#define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +#define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) @@ -596,6 +601,12 @@ #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) +#define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; +#define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +#define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; +#define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) @@ -624,6 +635,7 @@ #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) +#define ALL_INDICES nd4j::DataType::INT32, nd4j::DataType::INT64 #define ALL_INTS nd4j::DataType::INT8, nd4j::DataType::UINT8, nd4j::DataType::INT16, nd4j::DataType::UINT16, nd4j::DataType::INT32, nd4j::DataType::UINT32, nd4j::DataType::INT64, nd4j::DataType::UINT64 #define ALL_FLOATS nd4j::DataType::HALF, nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16 diff --git a/libnd4j/include/types/types.h b/libnd4j/include/types/types.h index b11f44c6e..9c8dcb273 100644 --- a/libnd4j/include/types/types.h +++ b/libnd4j/include/types/types.h @@ -76,6 +76,10 @@ (nd4j::DataType::FLOAT32, float), \ (nd4j::DataType::DOUBLE, double) +#define INDEXING_TYPES \ + (nd4j::DataType::INT32, int32_t), \ + (nd4j::DataType::INT64, Nd4jLong) + #define FLOAT_NATIVE \ (nd4j::DataType::FLOAT32, float), \ (nd4j::DataType::DOUBLE, double) diff --git a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp index 5d3ce0ea7..0fa4d687d 100644 --- a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp @@ -34,6 +34,8 @@ public: int dimensionLength = 2; }; +#ifndef __CUDABLAS__ + TEST_F(BroadcastMultiDimTest,MultimDimTest) { shape::TAD *tad = new shape::TAD(); tad->init(inputShapeBuffer,dimensions,dimensionLength); @@ -58,4 +60,6 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) { } delete tad; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu index 65509a1d4..19f107ea4 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -452,7 +452,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_20) { ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); } - +/* ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxM_21) { @@ -600,6 +600,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_28) { ASSERT_TRUE(c.equalsTo(&exp)); } + */ ////////////////////////////////////////////////////////////////////////// TEST_F(CudaBasicsTests2, mmulMxV_1) { @@ -918,6 +919,7 @@ TEST_F(CudaBasicsTests2, mmulMxV_18) { } ////////////////////////////////////////////////////////////////////////// +/* TEST_F(CudaBasicsTests2, mmulMxV_19) { const Nd4jLong M = 3; @@ -1150,4 +1152,5 @@ TEST_F(CudaBasicsTests2, mmulDot_4) { nd4j::MmulHelper::mmul(&x, &y, &z); ASSERT_TRUE(z.equalsTo(&exp)); -} \ No newline at end of file +} + */ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 5bf5f8013..3a4552790 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -55,9 +55,9 @@ TEST_F(DataTypesValidationTests, Basic_Test_1) { } TEST_F(DataTypesValidationTests, Basic_Test_2) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); weights.assign(2.0); input.linspace(1); @@ -75,10 +75,10 @@ TEST_F(DataTypesValidationTests, Basic_Test_2) { TEST_F(DataTypesValidationTests, Basic_Test_3) { - auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); - auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); - auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); weights.assign(2.0); input.linspace(1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 8835ecc8e..5375f8fca 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); -// result->printIndexedBuffer("OOOOUUUUTTT"); + //result->printIndexedBuffer("OOOOUUUUTTT"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); NDArray* result = results->at(0); -// result->printBuffer("NonMaxSuppression OUtput2"); + result->printBuffer("NonMaxSuppression OUtput2"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 7f66e9be3..6fe3dfac6 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -729,8 +729,8 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, tensormmul_6) { - NDArray x('c', {1}, {2}); - NDArray y('c', {2,1,2}, {1,2,3,4}); + NDArray x('c', {1}, {2}, nd4j::DataType::FLOAT32); + NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); nd4j::ops::tensormmul op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 53c1e9a99..06d677b27 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -24,6 +24,7 @@ #include #include #include +#include using namespace nd4j; @@ -363,6 +364,77 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { delete result; } +TEST_F(DeclarableOpsTests15, test_concat_column_1) { + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); + + nd4j::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); + + z.printIndexedBuffer("z"); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_concat_large_1) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 300}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1}); + + ASSERT_NEAR((float) e, row->e(0), 1e-5f); + + delete row; + } +} + +TEST_F(DeclarableOpsTests15, test_concat_large_2) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 5, 20}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1, 2}); + + ASSERT_NEAR((float) e, row->meanNumber().e(0), 1e-5f); + + delete row; + } +} + TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); diff --git a/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu new file mode 100644 index 000000000..d7632ace5 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/LaunchContextCudaTests.cu @@ -0,0 +1,125 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace nd4j; +using namespace nd4j::ops; + +class LaunchContextCudaTests : public testing::Test { + // +}; + + +void acquireContext(int threadId, int &deviceId) { + deviceId = AffinityManager::currentDeviceId(); + + nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId); + + auto lc = LaunchContext::defaultContext(); + nd4j_printf("LC: [%p]\n", lc); + + nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream()); +} + +TEST_F(LaunchContextCudaTests, basic_test_1) { + int deviceA, deviceB; + std::thread threadA(acquireContext, 0, std::ref(deviceA)); + std::thread threadB(acquireContext, 1, std::ref(deviceB)); + + threadA.join(); + threadB.join(); + nd4j_printf("All threads joined\n",""); + + if (AffinityManager::numberOfDevices() > 1) + ASSERT_NE(deviceA, deviceB); +} + +void fillArray(int tid, std::vector &arrays) { + auto array = NDArrayFactory::create_('c', {3, 10}); + nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId()); + array->assign(tid); + arrays[tid] = array; +} + +TEST_F(LaunchContextCudaTests, basic_test_2) { + std::vector arrays(2); + + std::thread threadA(fillArray, 0, std::ref(arrays)); + std::thread threadB(fillArray, 1, std::ref(arrays)); + + threadA.join(); + threadB.join(); + + for (int e = 0; e < 2; e++) { + auto array = arrays[e]; + ASSERT_EQ(e, array->e(0)); + + delete array; + } +} + +void initAffinity(int tid, std::vector &aff) { + auto affinity = AffinityManager::currentDeviceId(); + aff[tid] = affinity; + nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); +} + +TEST_F(LaunchContextCudaTests, basic_test_3) { + auto totalThreads = AffinityManager::numberOfDevices() * 4; + nd4j_printf("Total threads: %i\n", totalThreads); + std::vector affinities(totalThreads); + + for (int e = 0; e < totalThreads; e++) { + std::thread thread(initAffinity, e, std::ref(affinities)); + + thread.join(); + } + + std::vector hits(AffinityManager::numberOfDevices()); + std::fill(hits.begin(), hits.end(), 0); + + // we need to make sure all threads were attached to "valid" devices + for (int e = 0; e < totalThreads; e++) { + auto aff = affinities[e]; + ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); + + hits[aff]++; + } + + // now we check if all devices got some threads + for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { + ASSERT_GT(hits[e], 0); + } +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index 9151b70bd..aeac06ccb 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -465,6 +465,7 @@ TEST_F(LegacyOpsTests, PowDerivative_1) { ASSERT_TRUE(exp.equalsTo(&x)); } +#ifndef __CUDABLAS__ TEST_F(LegacyOpsTests, reduce3_1) { Nd4jLong yShape[2] = {4,4}; @@ -494,6 +495,8 @@ TEST_F(LegacyOpsTests, reduce3_1) { delete[] xShapeBuffer; } +#endif + TEST_F(LegacyOpsTests, Reduce3_2) { auto x = NDArrayFactory::create('c', {5, 5}); diff --git a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp index 2545bf919..e4c28e9ba 100644 --- a/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PairwiseTests.cpp @@ -32,6 +32,7 @@ public: int dimensionLength = 1; }; +#ifndef __CUDABLAS__ TEST_F(EqualsTest,Eps) { auto val = nd4j::NDArrayFactory::create(0.0f); @@ -45,3 +46,5 @@ TEST_F(EqualsTest,Eps) { val.shapeInfo()); ASSERT_TRUE(val.e(0) < 0.5); } + +#endif diff --git a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp index c6bb7af5b..608ee443f 100644 --- a/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/QuantizationTests.cpp @@ -31,13 +31,17 @@ class QuantizationTests : public testing::Test { }; TEST_F(QuantizationTests, Basic_Test_1) { +#ifndef __CUDABLAS__ auto s = TypeCast::estimateQuantizedSize(10); ASSERT_EQ(18, s); +#endif } TEST_F(QuantizationTests, Basic_Test_2) { +#ifndef __CUDABLAS__ auto s = TypeCast::estimateQuantizedSize(1); ASSERT_EQ(9, s); +#endif } TEST_F(QuantizationTests, Compression_Test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp index 0877632d5..4df0f3dc8 100644 --- a/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp @@ -62,7 +62,7 @@ public: std::vector dim = {1, 2, 3}; }; - +#ifndef __CUDABLAS__ TEST_F(EuclideanDistanceTest,Test1) { //int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength); nd4j::ArrayOptions::setDataType(shapeBuffer, nd4j::DataType::FLOAT32); @@ -152,4 +152,6 @@ TEST_F(ReduceTest,MatrixTest) { delete tad; delete[] xShapeInfo; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp index 7ea44aa17..1f352dd2e 100644 --- a/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TypeCastTests.cpp @@ -32,6 +32,7 @@ public: }; TEST_F(TypeCastTests, Test_Cast_1) { +#ifndef __CUDABLAS__ const int limit = 100; auto src = new double[limit]; auto z = new float[limit]; @@ -51,6 +52,7 @@ TEST_F(TypeCastTests, Test_Cast_1) { delete[] src; delete[] z; delete[] exp; +#endif } TEST_F(TypeCastTests, Test_ConvertDtype_1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java index 15ae181f2..5625db5a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java @@ -34,20 +34,6 @@ public interface AffinityManager { */ Integer getDeviceForCurrentThread(); - /** - * This method returns deviceId for specified thread - * @param thread - * @return - */ - Integer getDeviceForThread(Thread thread); - - /** - * This method returns deviceId for specified threadId - * - * @param threadId - * @return - */ - Integer getDeviceForThread(long threadId); /** * This method returns id of current device for a given INDArray @@ -57,23 +43,6 @@ public interface AffinityManager { */ Integer getDeviceForArray(INDArray array); - /** - * This method attaches specified thread to specified device - * - * @param thread - * @param deviceId - */ - void attachThreadToDevice(Thread thread, Integer deviceId); - - - /** - * This method attaches specified thread (by Id) to specified device - * - * @param threadId java ID of the thread - * @param deviceId - */ - void attachThreadToDevice(long threadId, Integer deviceId); - /** * This method returns number of available devices * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java index 40947e7fc..ad0320825 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java @@ -28,26 +28,6 @@ public abstract class BasicAffinityManager implements AffinityManager { return 0; } - @Override - public Integer getDeviceForThread(Thread thread) { - return 0; - } - - @Override - public Integer getDeviceForThread(long threadId) { - return 0; - } - - @Override - public void attachThreadToDevice(Thread thread, Integer deviceId) { - // no-op - } - - @Override - public void attachThreadToDevice(long threadId, Integer deviceId) { - // no-op - } - @Override public Integer getDeviceForArray(INDArray array) { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 3b772dba9..b1c4b34ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -25,6 +25,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp; @@ -43,6 +44,12 @@ public class Concat extends DynamicCustomOp { } + public Concat(int concatDimension, INDArray... arrays) { + super(null, arrays, new INDArray[0]); + this.concatDimension = concatDimension; + addIArgument(concatDimension); + } + public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){ super(null, sameDiff, inputs); addIArgument(concatDimension); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java index e8640bea6..500a1e123 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java @@ -129,13 +129,8 @@ public class AsyncDataSetIterator implements DataSetIterator { if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); - this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, null); + this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, null, deviceId); - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); thread.setDaemon(true); thread.start(); } @@ -229,12 +224,7 @@ public class AsyncDataSetIterator implements DataSetIterator { backedIterator.reset(); shouldWork.set(true); - this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, null); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, null, deviceId); thread.setDaemon(true); thread.start(); @@ -391,13 +381,15 @@ public class AsyncDataSetIterator implements DataSetIterator { .policySpill(SpillPolicy.REALLOCATE).build(); private MemoryWorkspace workspace; + private final int deviceId; protected AsyncPrefetchThread(@NonNull BlockingQueue queue, @NonNull DataSetIterator iterator, - @NonNull DataSet terminator, MemoryWorkspace workspace) { + @NonNull DataSet terminator, MemoryWorkspace workspace, int deviceId) { this.queue = queue; this.iterator = iterator; this.terminator = terminator; + this.deviceId = deviceId; this.setDaemon(true); this.setName("ADSI prefetch thread"); @@ -405,6 +397,7 @@ public class AsyncDataSetIterator implements DataSetIterator { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); externalCall(); try { if (useWorkspace) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java index 078db549a..ac20ce66d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java @@ -116,12 +116,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); - this.thread = new AsyncPrefetchThread(buffer, iterator, terminator); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, deviceId); thread.setDaemon(true); thread.start(); @@ -207,12 +202,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { backedIterator.reset(); shouldWork.set(true); - this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator); - - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); + this.thread = new AsyncPrefetchThread(buffer, backedIterator, terminator, deviceId); thread.setDaemon(true); thread.start(); @@ -340,13 +330,15 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { private MemoryWorkspace workspace; + private final int deviceId; + protected AsyncPrefetchThread(@NonNull BlockingQueue queue, - @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) { + @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) { this.queue = queue; this.iterator = iterator; this.terminator = terminator; - + this.deviceId = deviceId; this.setDaemon(true); this.setName("AMDSI prefetch thread"); @@ -354,6 +346,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); externalCall(); try { if (useWorkspaces) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index a5185ef88..429782c3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -58,6 +58,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.Stack; +import org.nd4j.linalg.api.ops.impl.shape.Tile; import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.impl.*; @@ -2557,15 +2558,17 @@ public class Nd4j { public static INDArray read(DataInputStream dis) { val headerShape = BaseDataBuffer.readHeader(dis); - var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle()}, headerShape.getRight()); + var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird()); - DataType type; + val length = Shape.length(shapeInformation); + DataType type = null; DataBuffer data = null; val headerData = BaseDataBuffer.readHeader(dis); try { // current version contains dtype in extras data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight()); + type = ArrayOptionsHelper.dataType(shapeInformation.asLong()); } catch (ND4JUnknownDataTypeException e) { // manually setting data type type = headerData.getRight(); @@ -5318,25 +5321,7 @@ public class Nd4j { * @return the tiled ndarray */ public static INDArray tile(INDArray tile, @NonNull int... repeat) { - int d = repeat.length; - long[] shape = ArrayUtil.copy(tile.shape()); - long n = Math.max(tile.length(), 1); - if (d < tile.rank()) { - repeat = Ints.concat(ArrayUtil.nTimes(tile.rank() - d, 1), repeat); - } - for (int i = 0; i < shape.length; i++) { - if (repeat[i] != 1) { - tile = tile.reshape(-1, n).repeat(0, repeat[i]); - } - - long in = shape[i]; - long nOut = in * repeat[i]; - shape[i] = nOut; - n /= Math.max(in, 1); - - } - - return tile.reshape(shape); + return Nd4j.exec(new Tile(new INDArray[]{tile}, new INDArray[]{}, repeat))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index a38a4a198..26d850366 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -57,13 +57,12 @@ public class DeallocatorService { log.debug("Starting deallocator thread {}", e + 1); queues[e] = new ReferenceQueue<>(); + int deviceId = e % numDevices; // attaching queue to its own thread - deallocatorThreads[e] = new DeallocatorServiceThread(queues[e], e); + deallocatorThreads[e] = new DeallocatorServiceThread(queues[e], e, deviceId); deallocatorThreads[e].setName("DeallocatorServiceThread_" + e); deallocatorThreads[e].setDaemon(true); - int deviceId = e % numDevices; - Nd4j.getAffinityManager().attachThreadToDevice(deallocatorThreads[e], deviceId); deviceMap.get(deviceId).add(queues[e]); deallocatorThreads[e].start(); @@ -87,16 +86,19 @@ public class DeallocatorService { private final ReferenceQueue queue; private final int threadIdx; public static final String DeallocatorThreadNamePrefix = "DeallocatorServiceThread thread "; + private final int deviceId; - private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx) { + private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx, int deviceId) { this.queue = queue; this.threadIdx = threadIdx; this.setName(DeallocatorThreadNamePrefix + threadIdx); + this.deviceId = deviceId; setContextClassLoader(null); } @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); boolean canRun = true; long cnt = 0; while (canRun) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index b3885962a..174be9a7d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1175,4 +1175,14 @@ public interface NativeOps { String runFullBenchmarkSuit(boolean printOut); long getCachedMemory(int deviceId); + + OpaqueLaunchContext defaultLaunchContext(); + + Pointer lcScalarPointer(OpaqueLaunchContext lc); + Pointer lcReductionPointer(OpaqueLaunchContext lc); + Pointer lcAllocationPointer(OpaqueLaunchContext lc); + Pointer lcExecutionStream(OpaqueLaunchContext lc); + Pointer lcCopyStream(OpaqueLaunchContext lc); + Pointer lcBlasHandle(OpaqueLaunchContext lc); + Pointer lcSolverHandle(OpaqueLaunchContext lc); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java new file mode 100644 index 000000000..d5f3df5e8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueLaunchContext.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueLaunchContext extends Pointer { + public OpaqueLaunchContext(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java index 6ebdeda6f..df45f85e5 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/Allocator.java @@ -17,8 +17,6 @@ package org.nd4j.jita.allocator; import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; @@ -50,7 +48,7 @@ public interface Allocator { * * @return */ - ExternalContext getDeviceContext(); + CudaContext getDeviceContext(); /** * This methods specifies Mover implementation to be used internally @@ -170,8 +168,6 @@ public interface Allocator { FlowController getFlowController(); - ContextPool getContextPool(); - DataBuffer getConstantBuffer(int[] array); DataBuffer getConstantBuffer(float[] array); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java deleted file mode 100644 index aee69924e..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/ContextPack.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context; - -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; -import org.apache.commons.lang3.RandomUtils; -import org.nd4j.linalg.jcublas.context.CudaContext; - -import java.util.HashMap; -import java.util.Map; - -/** - * @author raver119@gmail.com - */ -public class ContextPack { - @Getter - @Setter - private Integer deviceId; - @Getter - private int availableLanes; - private Map lanes = new HashMap<>(); - - public ContextPack(int totalLanes) { - availableLanes = totalLanes; - } - - public ContextPack(CudaContext context) { - this.availableLanes = 1; - lanes.put(0, context); - } - - public void addLane(@NonNull Integer laneId, @NonNull CudaContext context) { - lanes.put(laneId, context); - context.setLaneId(laneId); - } - - public CudaContext getContextForLane(Integer laneId) { - return lanes.get(laneId); - } - - public int nextRandomLane() { - if (availableLanes == 1) - return 0; - return RandomUtils.nextInt(0, availableLanes); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java deleted file mode 100644 index 5e2eb20e6..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/BasicContextPool.java +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.RandomUtils; -import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.cuda.CUcontext; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; - - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Semaphore; - -import org.bytedeco.cuda.cublas.*; -import org.bytedeco.cuda.cusolver.*; -import static org.bytedeco.cuda.global.cublas.*; -import static org.bytedeco.cuda.global.cusolver.*; - -/** - * This is context pool implementation, addressing shared cublas allocations together with shared stream pools - * - * Each context given contains: - * 1. Stream for custom kernel invocations. - * 2. cuBLAS handle tied with separate stream. - * - * @author raver119@gmail.com - */ -@Slf4j -public class BasicContextPool implements ContextPool { - // TODO: number of max threads should be device-dependant - protected static final int MAX_STREAMS_PER_DEVICE = Integer.MAX_VALUE - 1; - - protected volatile Map cuPool = new ConcurrentHashMap<>(); - - protected volatile Map cublasPool = new ConcurrentHashMap<>(); - protected volatile Map solverPool = new ConcurrentHashMap<>(); - - protected volatile Map contextsPool = new ConcurrentHashMap<>(); - - protected volatile Map> contextsForDevices = new ConcurrentHashMap<>(); - - protected Semaphore lock = new Semaphore(1); - - protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - public BasicContextPool() { - - } - - public boolean containsContextForThread(long threadId) { - return contextsPool.containsKey(threadId); - } - - public CudaContext getContextForDevice(Integer deviceId) { - return acquireContextForDevice(deviceId); - } - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - /* - We should check, if we have context for this specific thread/device - If we don't have context for this thread - we should stick to one of existent contexts available at pool - */ - Long threadId = Thread.currentThread().getId(); - if (!contextsPool.containsKey(threadId)) { - // we don't have attached context for this thread. we should pick up existing context for target device (if any). - - try { - // this is lockable thing, but since it locks once per thread initialization, performance impact won't be big - lock.acquire(); - - if (!contextsForDevices.containsKey(deviceId)) { - contextsForDevices.put(deviceId, new ConcurrentHashMap()); - } - - // if we hadn't hit MAX_STREAMS_PER_DEVICE limit - we add new stream. Otherwise we use random one. - if (contextsForDevices.get(deviceId).size() < MAX_STREAMS_PER_DEVICE) { - log.debug("Creating new context..."); - CudaContext context = createNewStream(deviceId); - - getDeviceBuffers(context, deviceId); - - if (contextsForDevices.get(deviceId).size() == 0) { - // if we have no contexts created - it's just awesome time to attach cuBLAS handle here - log.debug("Creating new cuBLAS handle for device [{}]...", deviceId); - - //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); - - cublasHandle_t handle = createNewCublasHandle(context.getOldStream()); - context.setHandle(handle); - //context.setCublasStream(cublasStream); - - cublasPool.put(deviceId, handle); - - log.debug("Creating new cuSolver handle for device [{}]...", deviceId); - - cudaStream_t solverStream = createNewStream(deviceId).getOldStream(); - - cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream); - context.setSolverHandle(solverhandle); - context.setSolverStream(solverStream); - - solverPool.put(deviceId, solverhandle); - - } else { - // just pick handle out there - log.debug("Reusing blas here..."); - cublasHandle_t handle = cublasPool.get(deviceId); - context.setHandle(handle); - - log.debug("Reusing solver here..."); - cusolverDnHandle_t solverHandle = solverPool.get(deviceId); - context.setSolverHandle(solverHandle); - - // TODO: actually we don't need this anymore - // cudaStream_t cublasStream = new cudaStream_t(); - // JCublas2.cublasGetStream(handle, cublasStream); - // context.setCublasStream(cublasStream); - } - - // we need this sync to finish memset - context.syncOldStream(); - - contextsPool.put(threadId, context); - contextsForDevices.get(deviceId).put(contextsForDevices.get(deviceId).size(), context); - - return context; - } else { - Integer rand = RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE); - log.debug("Reusing context: " + rand); - - nativeOps.setDevice(deviceId); - - CudaContext context = contextsForDevices.get(deviceId).get(rand); - - contextsPool.put(threadId, context); - return context; - } - - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - lock.release(); - } - } - - return contextsPool.get(threadId); - } - - @Override - public void releaseContext(CudaContext context) { - // no-op - } - - protected CudaContext createNewStream(Integer deviceId) { - log.trace("Creating new stream for thread: [{}], device: [{}]...", Thread.currentThread().getId(), deviceId); - nativeOps.setDevice(deviceId); - - CudaContext context = new CudaContext(); - context.initOldStream(); - - return context; - } - - protected cublasHandle_t createNewCublasHandle() { - cublasContext pointer = new cublasContext(); - int result = cublasCreate_v2(pointer); - if (result != 0) { - throw new IllegalStateException("Can't create new cuBLAS handle! cuBLAS errorCode: [" + result + "]"); - } - - cublasHandle_t handle = new cublasHandle_t(pointer); - - return handle; - } - - - protected cublasHandle_t createNewCublasHandle(cudaStream_t stream) { - return createNewCublasHandle(); - } - - protected cusolverDnHandle_t createNewSolverHandle() { - cusolverDnContext pointer = new cusolverDnContext(); - int result = cusolverDnCreate(pointer); - if (result != 0) { - throw new IllegalStateException("Can't create new cuBLAS handle! cusolverDn errorCode: [" + result - + "] from cusolverDnCreate()"); - } - - cusolverDnHandle_t handle = new cusolverDnHandle_t(pointer); - - return handle; - } - - protected cusolverDnHandle_t createNewSolverHandle(cudaStream_t stream) { - return createNewSolverHandle(); - } - - protected CUcontext createNewContext(Integer deviceId) { - /* - log.debug("Creating new CUcontext..."); - CUdevice device = new CUdevice(); - CUcontext context = new CUcontext(); - - //JCuda.cudaSetDevice(deviceId); - - - int result = cuDeviceGet(device, deviceId); - if (result != CUresult.CUDA_SUCCESS) { - throw new RuntimeException("Failed to setDevice on driver"); - } - - result = cuCtxCreate(context, 0, device); - if (result != CUresult.CUDA_SUCCESS) { - throw new RuntimeException("Failed to create context on driver"); - } - - return context; - */ - return null; - } - - /** - * This methods reset everything in pool, forcing recreation of all streams - * - * PLEASE NOTE: This is debugging-related method, and should NOT be used in real tasks - */ - public synchronized void resetPool(int deviceId) { - /* - for (CUcontext cuContext: cuPool.values()) { - log.debug("Destroying context: " + cuContext); - JCudaDriver.cuCtxDestroy(cuContext); - } - - cuPool.clear(); - contextsForDevices.clear(); - contextsPool.clear(); - cublasPool.clear(); - - solverPool.clear(); - - acquireContextForDevice(deviceId); - */ - } - - public CUcontext getCuContextForDevice(Integer deviceId) { - return cuPool.get(deviceId); - } - - /** - * This method is used to allocate - * @param context - * @param deviceId - */ - protected void getDeviceBuffers(CudaContext context, int deviceId) { - NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); //((CudaExecutioner) Nd4j.getExecutioner()).getNativeOps(); - - // we hardcode sizeOf to sizeOf(double) - int sizeOf = 8; - - val reductionPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (reductionPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] reduction buffer memory!"); - - nativeOps.memsetAsync(reductionPointer, 0, 16384 * sizeOf, 0, context.getOldStream()); - - context.syncOldStream(); - - val allocationPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (allocationPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] allocation buffer memory!"); - - val scalarPointer = nativeOps.mallocHost(sizeOf, 0); - if (scalarPointer == null) - throw new IllegalStateException("Can't allocate [HOST] scalar buffer memory!"); - - context.setBufferScalar(scalarPointer); - context.setBufferAllocation(allocationPointer); - context.setBufferReduction(reductionPointer); - - val specialPointer = nativeOps.mallocDevice(16384 * sizeOf, deviceId, 0); - if (specialPointer == null) - throw new IllegalStateException("Can't allocate [DEVICE] special buffer memory!"); - - nativeOps.memsetAsync(specialPointer, 0, 16384 * sizeOf, 0, context.getOldStream()); - - context.setBufferSpecial(specialPointer); - } - - public ContextPack acquireContextPackForDevice(Integer deviceId) { - return new ContextPack(acquireContextForDevice(deviceId)); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java deleted file mode 100644 index 20f8adc21..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/LimitedContextPool.java +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import lombok.var; -import org.apache.commons.lang3.RandomUtils; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.garbage.DeallocatableThread; -import org.nd4j.jita.allocator.garbage.GarbageResourceReference; -import org.nd4j.jita.allocator.impl.AtomicAllocator; -import org.nd4j.jita.allocator.pointers.CudaPointer; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.api.memory.Deallocatable; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.nativeblas.NativeOps; -import org.nd4j.nativeblas.NativeOpsHolder; - -import java.lang.ref.ReferenceQueue; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.locks.LockSupport; - -/** - * @author raver119@gmail.com - */ -@Slf4j -public class LimitedContextPool extends BasicContextPool { - - // pool of free contexts - protected Map> pool = new HashMap<>(); - - // pool of used pools - protected Map acquired = new ConcurrentHashMap<>(); - //protected AtomicInteger currentPoolSize = new AtomicInteger(0); - protected List devicePoolSizes = new ArrayList<>(); - protected Map> queueMap = new HashMap<>(); - - protected ThreadLocal threadHooks = new ThreadLocal<>(); - - public LimitedContextPool() { - - int perDevicePool = CudaEnvironment.getInstance().getConfiguration().getPoolSize(); - -/* - for (int i = 0; i < 4; i++) { - val queue = new ReferenceQueue(); - val collector = new ResourceGarbageCollectorThread(i, queue); - collector.start(); - - collectors.put(i, collector); - queueMap.put(i, queue); - } -*/ - fillPoolWithResources(perDevicePool, false); - } - - protected void addResourcesToPool(int numResources) { - int device = AtomicAllocator.getInstance().getDeviceId(); - - val handle = createNewCublasHandle(); - for (int cnt = 0; cnt < numResources; cnt++) { - val context = createNewStream(device); - context.initOldStream(); - getDeviceBuffers(context, device); - context.setHandle(handle); - - context.syncOldStream(); - - pool.get(device).add(context); - } - } - - protected synchronized void fillPoolWithResources(int numResources, boolean restoreDevice) { - List devices = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices(); - - int cDevice = 0; - if (restoreDevice) { - cDevice = AtomicAllocator.getInstance().getDeviceId(); - } - - NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - for (Integer device : devices) { - nativeOps.setDevice(device); - pool.put(device, new LinkedBlockingQueue()); - devicePoolSizes.add(new AtomicInteger(numResources)); - - val handle = createNewCublasHandle(); - val solverHandle = createNewSolverHandle(); - for (int cnt = 0; cnt < numResources; cnt++) { - val context = createNewStream(device); - context.initOldStream(); - getDeviceBuffers(context, device); - context.setHandle(handle); - context.setSolverHandle(solverHandle); - - context.syncOldStream(); - - pool.get(device).add(context); - } - - - } - - if (restoreDevice) { - nativeOps.setDevice(cDevice); - } - } - - public void removeAcquired() { - val threadIdx = Thread.currentThread().getId(); - acquired.remove(threadIdx); - } - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - val threadIdx = Thread.currentThread().getId(); - var context = acquired.get(threadIdx); - if (context != null && deviceId == context.getDeviceId()) { - return context; - } - - //log.info("Setting device to {}", deviceId); - nativeOps.setDevice(deviceId); - context = pool.get(deviceId).poll(); - if (context != null) { - //val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue()); - //context.attachReference(reference); - context.setDeviceId(deviceId); - context.setThreadId(threadIdx); - val hook = new DeallocatableThread(Thread.currentThread(), context); - threadHooks.set(hook); - Nd4j.getDeallocatorService().pickObject(hook); - - - acquired.put(threadIdx, context); - return context; - } else { - - do { - try { - Nd4j.getMemoryManager().invokeGc(); - - context = pool.get(deviceId).poll(1, TimeUnit.SECONDS); - if (context != null) { - //val reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue()); - //context.attachReference(reference); - context.setDeviceId(deviceId); - context.setThreadId(threadIdx); - val hook = new DeallocatableThread(Thread.currentThread(), context); - threadHooks.set(hook); - Nd4j.getDeallocatorService().pickObject(hook); - - acquired.put(threadIdx, context); - } else { - val currentPoolSize = devicePoolSizes.get(deviceId); - synchronized (currentPoolSize) { - if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) { - addResourcesToPool(16); - - // there's possible race condition, but we don't really care - currentPoolSize.addAndGet(16); - log.warn("Initial pool size: {}; Current pool size: {}", CudaEnvironment.getInstance().getConfiguration().getPoolSize(), currentPoolSize.get()); - } else { - log.warn("Can't allocate new context, sleeping..."); - - Nd4j.getMemoryManager().invokeGc(); - try { - Thread.sleep(500); - } catch (Exception e) { - // - } - } - } - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } while (context == null); - - return context; - } - } - - @Override - @Deprecated - public ContextPack acquireContextPackForDevice(Integer deviceId) { - return new ContextPack(acquireContextForDevice(deviceId)); - } - - @Override - public CudaContext getContextForDevice(Integer deviceId) { - return acquireContextForDevice(deviceId); - } - - @Override - public void releaseContext(CudaContext context) { - val threadIdx = context.getThreadId(); - val deviceId = context.getDeviceId(); - - context.setThreadId(-1); - - acquired.remove(threadIdx); - pool.get(deviceId).add(context); - } - - /* - private class ResourceGarbageCollectorThread extends Thread implements Runnable { - private final ReferenceQueue queue; - - public ResourceGarbageCollectorThread(int threadId, @NonNull ReferenceQueue queue) { - this.queue = queue; - this.setDaemon(true); - this.setName("ResourceGC thread " + threadId); - } - - @Override - public void run() { - while (true) { - GarbageResourceReference reference = (GarbageResourceReference) queue.poll(); - if (reference != null) { - CudaContext context = reference.getContext(); - val threadId = reference.getThreadId(); - val deviceId = reference.getDeviceId(); - - // there's a chance context was already released - if (context.getThreadId() != threadId) - continue; - - pool.get(deviceId).add(context); - acquired.remove(threadId); - } else { - LockSupport.parkNanos(500000L); - } - } - } - } - */ -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java deleted file mode 100644 index 55a72f5c6..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/context/impl/PackedContextPool.java +++ /dev/null @@ -1,109 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.context.impl; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.jita.allocator.context.ContextPack; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; -import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; -import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; -import org.nd4j.jita.conf.CudaEnvironment; -import org.nd4j.linalg.jcublas.context.CudaContext; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * @author raver119@gmail.com - */ -@Deprecated -@Slf4j -public class PackedContextPool extends BasicContextPool implements ContextPool { - - protected static final int LANES_PER_THREAD = - CudaEnvironment.getInstance().getConfiguration().getCommandLanesNumber(); - - private volatile Map contextsPool = new ConcurrentHashMap<>(); - - @Override - public CudaContext acquireContextForDevice(Integer deviceId) { - return acquireContextPackForDevice(deviceId).getContextForLane(0); - } - - @Override - public ContextPack acquireContextPackForDevice(Integer deviceId) { - Long threadId = Thread.currentThread().getId(); - if (!contextsPool.containsKey(threadId)) { - try { - lock.acquire(); - - ContextPack pack = new ContextPack(LANES_PER_THREAD); - for (int c = 0; c < LANES_PER_THREAD; c++) { - CudaContext context = createNewStream(deviceId); - - getDeviceBuffers(context, deviceId); - - if (cublasPool.get(deviceId) == null) { - // if we have no contexts created - it's just awesome time to attach cuBLAS handle here - log.debug("Creating new cuBLAS handle for device [{}]", deviceId); - - //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); - - cublasHandle_t handle = createNewCublasHandle(context.getOldStream()); - context.setHandle(handle); - //context.setCublasStream(cublasStream); - - cublasPool.put(deviceId, handle); - - log.debug("Creating new cuSolver handle for device [{}]...", deviceId); - - cudaStream_t solverStream = createNewStream(deviceId).getOldStream(); - - cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream); - context.setSolverHandle(solverhandle); - context.setSolverStream(solverStream); - - solverPool.put(deviceId, solverhandle); - - } else { - // just pick handle out there - log.debug("Reusing cuBLAS handle for device [{}]", deviceId); - cublasHandle_t handle = cublasPool.get(deviceId); - context.setHandle(handle); - - log.debug("Reusing solver here..."); - cusolverDnHandle_t solverHandle = solverPool.get(deviceId); - context.setSolverHandle(solverHandle); - } - - pack.addLane(c, context); - } - - contextsPool.put(threadId, pack); - - - } catch (Exception e) { - throw new RuntimeException(e); - } finally { - lock.release(); - } - } - - return contextsPool.get(threadId); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java deleted file mode 100644 index b850f7836..000000000 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/garbage/DeallocatableThread.java +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.jita.allocator.garbage; - -import org.nd4j.linalg.api.memory.Deallocatable; -import org.nd4j.linalg.api.memory.Deallocator; -import org.nd4j.linalg.jcublas.context.CudaContext; - -/** - * This class enables Thread tracking via DeallocatorService - * @author raver119@gmail.com - */ -public class DeallocatableThread implements Deallocatable { - private long threadId; - private CudaContext context; - - public DeallocatableThread(Thread thread, CudaContext context) { - this.threadId = thread.getId(); - this.context = context; - } - - @Override - public String getUniqueId() { - return "thread_" + threadId; - } - - @Override - public Deallocator deallocator() { - return new ContextDeallocator(context); - } - - @Override - public int targetDevice() { - return context.getDeviceId(); - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index bbefcb0fc..ad4cad0b0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -22,8 +22,6 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.Aggressiveness; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.garbage.GarbageBufferReference; @@ -226,7 +224,7 @@ public class AtomicAllocator implements Allocator { * @return */ @Override - public ExternalContext getDeviceContext() { + public CudaContext getDeviceContext() { // FIXME: proper lock avoidance required here return memoryHandler.getDeviceContext(); } @@ -290,7 +288,7 @@ public class AtomicAllocator implements Allocator { } public Pointer getPointer(DataBuffer buffer) { - return memoryHandler.getDevicePointer(buffer, (CudaContext) getDeviceContext().getContext()); + return memoryHandler.getDevicePointer(buffer, getDeviceContext()); } /** @@ -1072,11 +1070,6 @@ public class AtomicAllocator implements Allocator { return memoryHandler.getFlowController(); } - @Override - public ContextPool getContextPool() { - return memoryHandler.getContextPool(); - } - @Override public DataBuffer getConstantBuffer(int[] array) { return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.INT); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 0655e7cb1..77c709487 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -68,101 +68,9 @@ public class CudaAffinityManager extends BasicAffinityManager { */ @Override public Integer getDeviceForCurrentThread() { - return getDeviceForThread(Thread.currentThread().getId()); + return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); } - /** - * This method returns deviceId for given thread. - * - * If no device was assigned to this thread before this call, it'll be assinged here. - * @param thread - * @return - */ - @Override - public Integer getDeviceForThread(Thread thread) { - return getDeviceForThread(thread.getId()); - } - - /** - * This method returns deviceId for given thread, identified by threadId - * - * If no device was assigned to this thread before this call, it'll be assinged here. - * - * @param threadId - * @return - */ - @Override - public Integer getDeviceForThread(long threadId) { - if (getNumberOfDevices() == 1) - return 0; - - Integer aff = affinityMap.get(threadId); - - if (aff == null) { - Integer deviceId = getNextDevice(threadId); - affinityMap.put(threadId, deviceId); - affiliated.set(new AtomicBoolean(false)); - - if (threadId == Thread.currentThread().getId()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - //logger.error("setDevice({}) called for thread {}", deviceId, Thread.currentThread().getName()); - affiliated.get().set(true); - } - - return deviceId; - } else { - - if (threadId == Thread.currentThread().getId()) { - if (affiliated.get() == null) - affiliated.set(new AtomicBoolean(false)); - - if (!affiliated.get().get()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(aff); - //logger.error("SCARY setDevice({}) called for thread {}", aff, threadId); - affiliated.get().set(true); - return aff; - } - } - - return aff; - } -/* - - - return affinityMap.get(threadId); -*/ - //return 0; - } - - /** - * This method pairs specified thread & device - * - * @param thread - * @param deviceId - */ - @Override - public void attachThreadToDevice(Thread thread, Integer deviceId) { - attachThreadToDevice(thread.getId(), deviceId); - } - - /** - * This method pairs specified thread & device - * - * @param threadId - * @param deviceId - */ - @Override - public void attachThreadToDevice(long threadId, Integer deviceId) { - val t = Thread.currentThread(); - String name = "N/A"; - if (t.getId() == threadId) - name = t.getName(); - - List devices = new ArrayList<>(CudaEnvironment.getInstance().getConfiguration().getAvailableDevices()); - logger.trace("Manually mapping thread [{} - {}] to device [{}], out of [{}] devices...", threadId, - name, deviceId, devices.size()); - affinityMap.put(threadId, deviceId); - } /** * This method returns device id available. Round-robin balancing used here. @@ -275,14 +183,13 @@ public class CudaAffinityManager extends BasicAffinityManager { val empty = array.isEmpty(); // we use this call to get device memory updated - AtomicAllocator.getInstance().getPointer(array, (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()); + AtomicAllocator.getInstance().getPointer(array, AtomicAllocator.getInstance().getDeviceContext()); int currentDeviceId = getDeviceForCurrentThread(); if (currentDeviceId != deviceId.intValue()) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - attachThreadToDevice(Thread.currentThread().getId(), deviceId); + unsafeSetDevice(deviceId); } @@ -292,8 +199,7 @@ public class CudaAffinityManager extends BasicAffinityManager { if (currentDeviceId != deviceId.intValue()) { Nd4j.getMemoryManager().releaseCurrentContext(); - attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(currentDeviceId); + unsafeSetDevice(currentDeviceId); } @@ -312,11 +218,11 @@ public class CudaAffinityManager extends BasicAffinityManager { if (buffer == null) return null; - int currentDeviceId = AtomicAllocator.getInstance().getDeviceId(); + int currentDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + if (currentDeviceId != deviceId) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId); - Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), deviceId); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); } DataBuffer dstBuffer = Nd4j.createBuffer(buffer.dataType(), buffer.length(), false); @@ -324,8 +230,7 @@ public class CudaAffinityManager extends BasicAffinityManager { if (currentDeviceId != deviceId) { Nd4j.getMemoryManager().releaseCurrentContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(currentDeviceId); - Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), currentDeviceId); + Nd4j.getAffinityManager().unsafeSetDevice(currentDeviceId); } return dstBuffer; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 54e5df7e6..5548d854a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -143,7 +143,7 @@ public class ProtectedCudaConstantHandler implements ConstantHandler { AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, deviceId, requiredMemoryBytes); long currentOffset = constantOffsets.get(deviceId).get(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) { if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index 9a8feeb0b..07cad5269 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -72,7 +72,7 @@ public class SynchronousFlowController implements FlowController { public void synchronizeToHost(AllocationPoint point) { if (!point.isActualOnHostSide()) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); if (!point.isConstant()) waitTillFinished(point); @@ -102,7 +102,7 @@ public class SynchronousFlowController implements FlowController { if (!point.isActualOnDeviceSide()) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); long perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -135,7 +135,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareActionAllWrite(INDArray... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); val cId = allocator.getDeviceId(); for (INDArray operand : operands) { @@ -168,7 +168,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(INDArray result, INDArray... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); val cId = allocator.getDeviceId(); @@ -290,7 +290,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) { - val context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); if (result != null) { result.acquireLock(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java index 4d11a6564..36d8e05fb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java @@ -19,8 +19,6 @@ package org.nd4j.jita.handler; import com.google.common.collect.Table; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AllocationShape; @@ -295,7 +293,7 @@ public interface MemoryHandler { * This method returns ExternalContext wrapper (if applicable) * @return */ - ExternalContext getDeviceContext(); + CudaContext getDeviceContext(); void registerAction(CudaContext context, INDArray result, INDArray... operands); @@ -306,8 +304,4 @@ public interface MemoryHandler { boolean promoteObject(DataBuffer buffer); void relocateObject(DataBuffer buffer); - - ContextPool getContextPool(); - - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index 02374315b..e35628728 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -25,10 +25,6 @@ import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker; -import org.nd4j.jita.allocator.context.ContextPool; -import org.nd4j.jita.allocator.context.ExternalContext; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; -import org.nd4j.jita.allocator.context.impl.PackedContextPool; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.enums.CudaConstants; import org.nd4j.jita.allocator.impl.AllocationPoint; @@ -37,6 +33,9 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.MemoryTracker; import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.PointersPair; +import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; +import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; +import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; import org.nd4j.jita.allocator.utils.AllocationUtils; import org.nd4j.jita.conf.Configuration; import org.nd4j.jita.conf.CudaEnvironment; @@ -99,8 +98,6 @@ public class CudaZeroHandler implements MemoryHandler { private final AtomicBoolean wasInitialised = new AtomicBoolean(false); - private final ContextPool contextPool; - @Getter private final MemoryProvider memoryProvider; @@ -142,7 +139,6 @@ public class CudaZeroHandler implements MemoryHandler { switch (configuration.getExecutionModel()) { case SEQUENTIAL: { this.flowController = new GridFlowController(); - this.contextPool = new LimitedContextPool(); } break; default: @@ -222,7 +218,7 @@ public class CudaZeroHandler implements MemoryHandler { boolean initialize) { long reqMemory = AllocationUtils.getRequiredMemory(shape); - CudaContext context = getCudaContext(); + val context = getCudaContext(); switch (targetMode) { case HOST: { if (MemoryTracker.getInstance().getActiveHostAmount() + reqMemory >= configuration.getMaximumZeroAllocation()) { @@ -1158,8 +1154,8 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ @Override - public ExternalContext getDeviceContext() { - return new ExternalContext(getCudaContext()); + public CudaContext getDeviceContext() { + return getCudaContext(); } /** @@ -1167,30 +1163,20 @@ public class CudaZeroHandler implements MemoryHandler { * @return */ public CudaContext getCudaContext() { - // FIXME: remove this before release - Integer deviceId = getDeviceId(); - return contextPool.acquireContextForDevice(deviceId); - } + val lc = nativeOps.defaultLaunchContext(); + // TODO: maybe make ThreadLocal cache for context? - /** - * This method does initialization for thread. - * - * - * @param threadId - */ - protected void initCudaContextForThread(Long threadId) { - - // we set device to be used prior to stream creation - - nativeOps.setDevice(getDeviceId()); - - CudaContext context = new CudaContext(); - context.initHandle(); - context.initOldStream(); - context.initStream(); - context.associateHandle(); - //contextPool.put(threadId, context); + return CudaContext.builder() + .bufferScalar(nativeOps.lcScalarPointer(lc)) + .bufferReduction(nativeOps.lcReductionPointer(lc)) + .bufferAllocation(nativeOps.lcAllocationPointer(lc)) + .bufferSpecial(nativeOps.lcScalarPointer(lc)) + .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc))) + .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc))) + .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc))) + .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) + .build(); } /** @@ -1227,11 +1213,4 @@ public class CudaZeroHandler implements MemoryHandler { public FlowController getFlowController() { return flowController; } - - @Override - public ContextPool getContextPool() { - return contextPool; - } - - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java index 3263bf291..da36da6db 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/memory/CudaMemoryManager.java @@ -19,7 +19,6 @@ package org.nd4j.jita.memory; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; @@ -79,7 +78,7 @@ public class CudaMemoryManager extends BasicMemoryManager { throw new RuntimeException("Failed to allocate " + bytes + " bytes from DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "] memory"); if (initialize) { - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); int i = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, bytes, 0, context.getSpecialStream()); if (i == 0) @@ -168,7 +167,7 @@ public class CudaMemoryManager extends BasicMemoryManager { */ @Override public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) { @@ -258,7 +257,7 @@ public class CudaMemoryManager extends BasicMemoryManager { AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array); if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + CudaContext context = AtomicAllocator.getInstance().getDeviceContext(); NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(AtomicAllocator.getInstance().getPointer(array, context),0, array.data().length() * Nd4j.sizeOfDataType(array.data().dataType()),0, context.getOldStream()); // we also memset host pointer @@ -289,20 +288,6 @@ public class CudaMemoryManager extends BasicMemoryManager { @Override public void releaseCurrentContext() { - // gettting context for this thread - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - - if (context == null) - return; - - // we dont want any remnaints below this line - context.syncOldStream(); - context.syncSpecialStream(); - - val pool = AtomicAllocator.getInstance().getContextPool(); - - // push it back to pool - pool.releaseContext(context); - ((LimitedContextPool) pool).removeAcquired(); + throw new UnsupportedOperationException("Not implemented yet"); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index 5e1d2eeaf..c901cdd67 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -177,7 +177,7 @@ public class CudaWorkspace extends Nd4jWorkspace { log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements, prevOffset, deviceOffset.get(), currentSize.get(), ptr.address()); if (initialize) { - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); int ret = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream()); if (ret == 0) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java index 8b37e8ead..cd0356d18 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArray.java @@ -570,7 +570,7 @@ public class JCublasNDArray extends BaseNDArray { //Nd4j.getExecutioner().commit(); AtomicAllocator allocator = AtomicAllocator.getInstance(); - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = (CudaContext) allocator.getDeviceContext(); AllocationPoint srcPoint = allocator.getAllocationPoint(this); AllocationPoint dstPoint = allocator.getAllocationPoint(ret); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index b0db6ec50..44c361d87 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataTypeEx; import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.memory.enums.MemoryKind; +import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; @@ -410,6 +411,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + return Nd4j.exec(new Concat(dimension, toConcat))[0]; + + // legacy implementation +/* boolean allScalars = true; var outputShape = ArrayUtil.copy(toConcat[0].shape()); @@ -531,6 +536,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { return ret; //return super.concat(dimension, toConcat); + */ } @@ -546,7 +552,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { PointerPointer dataPointers = new PointerPointer(toConcat.length); AtomicAllocator allocator = AtomicAllocator.getInstance(); - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); int sumAlongDim = 0; @@ -783,10 +789,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { Nd4j.getExecutioner().commit(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer dataPointers = new PointerPointer(arrays.length); - PointerPointer extras = new PointerPointer(null, // not used + val dataPointers = new PointerPointer(arrays.length); + val extras = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1) ); for (int i = 0; i < arrays.length; i++) { @@ -899,10 +905,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { */ long len = target == null ? arrays[0].lengthLong() : target.lengthLong(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer dataPointers = new PointerPointer(arrays.length); - PointerPointer extras = new PointerPointer(null, // not used + val dataPointers = new PointerPointer(arrays.length); + val extras = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), new CudaPointer(1) ); for (int i = 0; i < arrays.length; i++) { @@ -1249,7 +1255,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) { - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); val p = new PointerPointer<>(new Pointer[]{null, stream}); @@ -1262,7 +1268,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { Pointer dstPtr = null; long size = 0; long ssize = 0; - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); if (buffer instanceof CompressedDataBuffer) { // compressing size = ((CompressedDataBuffer) buffer).getCompressionDescriptor().getCompressedLength(); @@ -1291,7 +1297,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { @Override public void convertDataEx(DataTypeEx typeSrc, DataBuffer source, DataTypeEx typeDst, DataBuffer target) { - val stream = ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream(); + val stream = AtomicAllocator.getInstance().getDeviceContext().getOldStream(); Pointer srcPtr = null; Pointer dstPtr = null; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java index 592e9f3f6..74a8fc99c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java @@ -75,7 +75,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -142,7 +142,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -214,7 +214,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -330,7 +330,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -439,7 +439,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -523,7 +523,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -656,7 +656,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -766,7 +766,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -853,7 +853,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = (CudaContext) allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); @@ -928,7 +928,7 @@ public class JcublasLapack extends BaseLapack { ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); // Get context for current thread - CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext(); + val ctx = allocator.getDeviceContext(); // setup the solver handles for cuSolver calls cusolverDnHandle_t handle = ctx.getSolverHandle(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java index 7009bfbaa..d4efb87b4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java @@ -100,7 +100,7 @@ public class JcublasLevel1 extends BaseLevel1 { val xCPointer = new CublasPointer(X, ctx); val yCPointer = new CublasPointer(Y, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); val cctx = new cublasContext(handle); synchronized (handle) { @@ -144,7 +144,7 @@ public class JcublasLevel1 extends BaseLevel1 { val xCPointer = new CublasPointer(X, ctx); val yCPointer = new CublasPointer(Y, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { val cctx = new cublasContext(handle); cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream())); @@ -177,7 +177,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer cAPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -235,7 +235,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer cAPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -276,7 +276,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -306,7 +306,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -337,7 +337,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -361,7 +361,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -395,7 +395,7 @@ public class JcublasLevel1 extends BaseLevel1 { // CublasPointer xAPointer = new CublasPointer(X, ctx); // CublasPointer xBPointer = new CublasPointer(Y, ctx); - // cublasHandle_t handle = ctx.getHandle(); + // cublasHandle_t handle = ctx.getCublasHandle(); ((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha)); @@ -424,7 +424,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -446,7 +446,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); CublasPointer yCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -540,7 +540,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -563,7 +563,7 @@ public class JcublasLevel1 extends BaseLevel1 { CublasPointer xCPointer = new CublasPointer(X, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java index 652d7b328..05f33ac3e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java @@ -62,7 +62,7 @@ public class JcublasLevel2 extends BaseLevel2 { CublasPointer cBPointer = new CublasPointer(X, ctx); CublasPointer cCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -134,7 +134,7 @@ public class JcublasLevel2 extends BaseLevel2 { CublasPointer cBPointer = new CublasPointer(X, ctx); CublasPointer cCPointer = new CublasPointer(Y, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index 99a9718c8..7f8f9bb51 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -72,7 +72,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer cBPointer = new CublasPointer(B, ctx); CublasPointer cCPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -118,7 +118,7 @@ public class JcublasLevel3 extends BaseLevel3 { val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -144,7 +144,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -169,7 +169,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -206,7 +206,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -236,7 +236,7 @@ public class JcublasLevel3 extends BaseLevel3 { val cBPointer = new CublasPointer(B, ctx); val cCPointer = new CublasPointer(C, ctx); - val handle = ctx.getHandle(); + val handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -261,7 +261,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -286,7 +286,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -311,7 +311,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer bPointer = new CublasPointer(B, ctx); CublasPointer cPointer = new CublasPointer(C, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -336,7 +336,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); @@ -362,7 +362,7 @@ public class JcublasLevel3 extends BaseLevel3 { CublasPointer aPointer = new CublasPointer(A, ctx); CublasPointer bPointer = new CublasPointer(B, ctx); - cublasHandle_t handle = ctx.getHandle(); + cublasHandle_t handle = ctx.getCublasHandle(); synchronized (handle) { cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream())); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 8d29e2b7b..1bec19dd0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -121,7 +121,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda Nd4j.getDeallocatorService().pickObject(this); // now we're - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); val perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -1522,7 +1522,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda lazyAllocateHostPointer(); } - val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java index bee68b3bf..19e8f8df6 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/compression/CudaThreshold.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.jcublas.compression; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.apache.commons.math3.util.FastMath; import org.bytedeco.javacpp.*; import org.nd4j.compression.impl.AbstractCompressor; @@ -118,7 +119,7 @@ public class CudaThreshold extends AbstractCompressor { DataBuffer result = Nd4j.createBuffer(type, originalLength, false); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); PointerPointer extras = new PointerPointer(32).put(1, context.getOldStream()); @@ -139,7 +140,7 @@ public class CudaThreshold extends AbstractCompressor { int numThreads = 1024; int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0 ? 0 : 1)); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); DataBuffer blocksBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(numBlocks+1, true) : Nd4j.getDataBufferFactory().createInt(numBlocks+1, true, Nd4j.getMemoryManager().getCurrentWorkspace()); PointerPointer extras = new PointerPointer(32).put(1, context.getOldStream()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java index a1f1b39be..826bb0797 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java @@ -16,8 +16,7 @@ package org.nd4j.linalg.jcublas.context; -import lombok.Data; -import lombok.val; +import lombok.*; import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -44,49 +43,32 @@ import java.util.concurrent.atomic.AtomicBoolean; * */ @Data +@AllArgsConstructor +@NoArgsConstructor +@Builder public class CudaContext { - //private CUcontext context; - //private CUstream stream; - //private CUevent cUevent; + + // execution stream private cudaStream_t oldStream; - private cudaStream_t solverStream; - + // memcpy stream private cudaStream_t specialStream; - //private cudaEvent_t oldEvent; - private cublasHandle_t handle; + // exactly what it says + private cublasHandle_t cublasHandle; private cusolverDnHandle_t solverHandle; - private CublasPointer resultPointer; - private AtomicBoolean oldStreamReturned = new AtomicBoolean(false); - private AtomicBoolean handleReturned = new AtomicBoolean(false); - private AtomicBoolean streamReturned = new AtomicBoolean(false); - private boolean streamFromPool = true; - private boolean handleFromPool = true; - private boolean oldStreamFromPool = true; - private boolean free = true; - private boolean oldEventDestroyed = true; - private boolean eventDestroyed = true; + // temporary buffers, exactly 1 per thread private Pointer bufferReduction; private Pointer bufferAllocation; private Pointer bufferScalar; + + // legacy. to be removed. private Pointer bufferSpecial; - private GarbageResourceReference reference; private int deviceId = -1; - private long threadId; - - private int laneId = 0; - - private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - - - public CudaContext(boolean free) { - this(); - this.free = free; - } + private transient final static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override public String toString() { @@ -94,34 +76,16 @@ public class CudaContext { "bufferReduction=" + bufferReduction + ", bufferScalar=" + bufferScalar + ", deviceId=" + deviceId + - ", threadId=" + threadId + - ", laneId=" + laneId + '}'; } - public void attachReference(GarbageResourceReference ref) { - reference = ref; - } - - - public CudaContext() { - // ContextHolder.getInstance().setContext(); - } - - /** - * Synchronizes on the new - * stream - */ - public void syncStream() { - //JCudaDriver.cuStreamSynchronize(stream); - } - /** * Synchronizes * on the old stream */ public void syncOldStream() { - syncOldStream(false); + if (nativeOps.streamSynchronize(oldStream) == 0) + throw new ND4JIllegalStateException("CUDA stream synchronization failed"); } public void syncSpecialStream() { @@ -129,125 +93,21 @@ public class CudaContext { throw new ND4JIllegalStateException("CUDA special stream synchronization failed"); } - public void syncOldStream(boolean syncCuBlas) { - // ContextHolder.getInstance().setContext(); - if (nativeOps.streamSynchronize(oldStream) == 0) - throw new ND4JIllegalStateException("CUDA stream synchronization failed"); - } - public Pointer getCublasStream() { + // FIXME: can we cache this please val lptr = new PointerPointer(this.getOldStream()); return lptr.get(0); } - - public void syncSolverStream() { - if (solverStream != null) { - if (nativeOps.streamSynchronize(solverStream) == 0) - throw new ND4JIllegalStateException("CUDA stream synchronization failed"); - } else - throw new IllegalStateException("cuBLAS stream isnt set"); + public cublasHandle_t getCublasHandle() { + // FIXME: can we cache this please + val lptr = new PointerPointer(cublasHandle); + return new cublasHandle_t(lptr.get(0)); } - /** - * Associates - * the handle on this context - * to the given stream - */ - public synchronized void associateHandle() { - //JCublas2.cublasSetStream(handle,oldStream); + public cusolverDnHandle_t getSolverHandle() { + // FIXME: can we cache this please + val lptr = new PointerPointer(solverHandle); + return new cusolverDnHandle_t(lptr.get(0)); } - - - - /** - * Initializes the stream - */ - public void initStream() { - // ContextHolder.getInstance().setContext(); - /* - if(stream == null) { - stream = new CUstream(); - JCudaDriver.cuStreamCreate(stream, CUstream_flags.CU_STREAM_DEFAULT); - streamFromPool = false; - eventDestroyed = false; - } - */ - } - - /** - * Initializes the old stream - */ - public void initOldStream() { - // ContextHolder.getInstance().setContext(); - if (oldStream == null) { - oldStreamFromPool = false; - oldStream = new cudaStream_t(nativeOps.createStream()); - //JCuda.cudaStreamCreate(oldStream); - - specialStream = new cudaStream_t(nativeOps.createStream()); - //JCuda.cudaStreamCreate(specialStream); - } - - } - - - - /** - * Initializes a handle and - * associates with the given stream. - * initOldStream() should be called first - * - */ - public void initHandle() { - /* - - We don't create handles here anymore - - if(handle == null) { - handle = new cublasHandle(); - JCublas2.cublasCreate(handle); - handleFromPool = false; - } - */ - } - - /** - * Destroys the context - * and associated resources - */ - @Deprecated - public void destroy(CublasPointer resultPointer, boolean freeIfNotEqual) {} - - - /** - * Destroys the context - * and associated resources - */ - @Deprecated - public void destroy() { - - } - - - /** - * Finishes a blas operation - * and destroys this context - */ - public void finishBlasOperation() { - //destroy(); - } - - /** - * Sets up a context with an old stream - * and a blas handle - * @return the cuda context - * as setup for cublas usage - */ - public static CudaContext getBlasContext() { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - //context.syncOldStream(false); - return context; - } - } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 3d1f66ef8..26f228430 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -60,6 +60,7 @@ import org.nd4j.linalg.cache.TADManager; import org.nd4j.linalg.compression.ThresholdCompression; import org.nd4j.linalg.exception.ND4JIllegalArgumentException; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.buffer.AddressRetriever; import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer; @@ -1513,11 +1514,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { val surfaceBuffer = (BaseCudaDataBuffer) getBuffer(batch); surfaceBuffer.lazyAllocateHostPointer(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); - IntPointer pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) + val pointer = (IntPointer) new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)) .asIntPointer(); - AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); + val surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer); int maxTypes = 5; @@ -1659,7 +1660,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { this.exec(single); } - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); context.syncOldStream(); } @@ -1671,9 +1672,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { int numIntArrays = op.getIntArrayArguments().size(); int numRealArguments = op.getRealArguments().size(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - PointerPointer extraArgs = new PointerPointer(32); + val extraArgs = new PointerPointer(32); extraArgs.put(0, null); extraArgs.put(1, context.getOldStream()); extraArgs.put(2, new CudaPointer(1)); @@ -1890,8 +1891,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void commit() { - ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).syncOldStream(); - ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).syncSpecialStream(); + AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); + AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); } @Override @@ -1901,14 +1902,14 @@ public class CudaExecutioner extends DefaultOpExecutioner { int numThreads = 1024; int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0 ? 0 : 1)); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); DataBuffer blocksBuffer = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(numBlocks+1, true) : Nd4j.getDataBufferFactory().createInt(numBlocks+1, true, Nd4j.getMemoryManager().getCurrentWorkspace()); if (extraz.get() == null) extraz.set(new PointerPointer(32)); - PointerPointer extras = extraz.get().put(1, context.getOldStream()); + val extras = extraz.get().put(1, context.getOldStream()); @@ -2024,7 +2025,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { DataBuffer result = target.data(); - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = AtomicAllocator.getInstance().getDeviceContext(); if (extraz.get() == null) extraz.set(new PointerPointer(32)); @@ -2254,10 +2255,11 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); val name = op.opName(); try (val context = (CudaOpContext) buildContext()) { + context.markInplace(op.isInplaceCall()); // transferring rng state @@ -2279,6 +2281,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); return result; + } catch (ND4JOpProfilerException e) { + throw e; } catch (Exception e) { throw new RuntimeException("Op [" + name + "] execution failed", e); } @@ -2545,11 +2549,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray[] exec(CustomOp op, OpContext context) { - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + long st = profilingConfigurableHookIn(op); + + val ctx = AtomicAllocator.getInstance().getDeviceContext(); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); + profilingConfigurableHookOut(op, st); + if (context.getOutputArrays().isEmpty()) return new INDArray[0]; else @@ -2559,7 +2567,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArrayStatistics inspectArray(@NonNull INDArray array) { val debugInfo = new Nd4jCuda.DebugInfo(); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); AtomicAllocator.getInstance().synchronizeHostData(array); if (extraz.get() == null) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index b2c86bf3a..ca8e4eb07 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -164,9 +164,9 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio } protected boolean compareDevicePointers(INDArray array, Op op) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); - Pointer pointer = AtomicAllocator.getInstance().getPointer(array, context); + val pointer = AtomicAllocator.getInstance().getPointer(array, context); long opZ = AtomicAllocator.getInstance().getPointer(op.z(), context).address(); long opX = AtomicAllocator.getInstance().getPointer(op.x(), context).address(); @@ -193,7 +193,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio protected boolean compareHostPointers(INDArray array, Op op) { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext(); Pointer pointer = AtomicAllocator.getInstance().getPointer(array, context); @@ -506,9 +506,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio AtomicAllocator allocator = AtomicAllocator.getInstance(); - // CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); - // FIXME: do not leave it as is - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = allocator.getDeviceContext(); pointers.setX(allocator.getPointer(op.x(), context)); pointers.setXShapeInfo(allocator.getPointer(op.x().shapeInfoDataBuffer(), context)); @@ -930,7 +928,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio public void flushQueueBlocking() { flushQueue(); - val context =((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()); + val context = AtomicAllocator.getInstance().getDeviceContext(); context.syncSpecialStream(); context.syncOldStream(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 8db04257b..749f5cc96 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -84,7 +84,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { // FIXME: remove Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setInputArray(index, array); @@ -94,7 +94,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void setOutputArray(int index, @NonNull INDArray array) { Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); - val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + val ctx = AtomicAllocator.getInstance().getDeviceContext(); nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setOutputArray(index, array); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 5b26a81ea..b15e4455e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3104,6 +3104,15 @@ public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); +public native OpaqueLaunchContext defaultLaunchContext(); +public native @Cast("Nd4jPointer") Pointer lcScalarPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcReductionPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcAllocationPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcExecutionStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); + // #endif //NATIVEOPERATIONS_NATIVEOPS_H @@ -9928,6 +9937,8 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 6b9979ec2..51b9ce7e4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -121,6 +121,7 @@ public class Nd4jCudaPresets implements InfoMapper { .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) + .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java index 7c21fc86f..c19adf4ad 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/AllocatorTest.java @@ -22,7 +22,6 @@ import org.apache.commons.lang3.RandomUtils; import org.bytedeco.javacpp.Pointer; import org.junit.Ignore; import org.junit.Test; -import org.nd4j.jita.allocator.context.impl.LimitedContextPool; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.MemoryTracker; @@ -539,15 +538,6 @@ public class AllocatorTest { assertEquals(currEventsNumber+5, controller.getEventsProvider().getEventsNumber()); } - @Test - public void testReleaseContext() { - LimitedContextPool pool = (LimitedContextPool) AtomicAllocator.getInstance().getContextPool(); - System.out.println(pool.acquireContextForDevice(0)); - INDArray x = Nd4j.rand(1,10); - pool.releaseContext(pool.getContextForDevice(0)); - System.out.println(pool.getContextForDevice(0)); - } - @Test public void testDataBuffers() { INDArray x = Nd4j.create(DataType.FLOAT, 10, 5); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java index cecfe07d0..9584d5692 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/test/java/org/nd4j/jita/allocator/DeviceLocalNDArrayTests.java @@ -38,14 +38,16 @@ public class DeviceLocalNDArrayTests { val dl = new DeviceLocalNDArray(arr); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); dl.get().addi(1.0); Nd4j.getExecutioner().commit(); } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); } @@ -60,9 +62,11 @@ public class DeviceLocalNDArrayTests { val dl = new DeviceLocalNDArray(arr); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); for (int i = 0; i < 10; i++) { val tmp = Nd4j.create(DataType.DOUBLE, shape); tmp.addi(1.0); @@ -70,7 +74,7 @@ public class DeviceLocalNDArrayTests { } } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); @@ -79,14 +83,16 @@ public class DeviceLocalNDArrayTests { System.gc(); for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); e++) { + val f = e; val t = new Thread(new Runnable() { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(f); dl.get().addi(1.0); Nd4j.getExecutioner().commit(); } }); - Nd4j.getAffinityManager().attachThreadToDevice(t, e); + t.start(); t.join(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 28c0b12b3..2b47103c3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.ops.custom.Flatten; +import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayType; @@ -572,6 +573,10 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { if (toConcat.length == 1) return toConcat[0]; + return Nd4j.exec(new Concat(dimension, toConcat))[0]; + + // legacy implementation +/* // if reusable var wasn't created for this thread, or is smaller then needed - set it to new value if (extrazA.get() == null || extrazB.get() == null || extrazSize.get() == null || extrazSize.get() < toConcat.length) { extrazA.set(new PointerPointer(toConcat.length)); @@ -627,6 +632,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { null, null); return ret; + */ } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 8c0bebfc1..52fe5c652 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3104,6 +3104,15 @@ public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); +public native OpaqueLaunchContext defaultLaunchContext(); +public native @Cast("Nd4jPointer") Pointer lcScalarPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcReductionPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcAllocationPointer(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcExecutionStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcCopyStream(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); +public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); + // #endif //NATIVEOPERATIONS_NATIVEOPS_H @@ -9723,10 +9732,13 @@ public static final int PREALLOC_SIZE = 33554432; // #ifndef __CLION_IDE__ // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLEU, NAME, (SIGNATURE), TYPES)) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SINGLE_T(RANDOMSINGLE, NAME, (SIGNATURE), TYPES)) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) EVAL(_EXEC_SELECTOR_T(TEMPLATE_SINGLE_TWICE, NAME, SIGNATURE, TYPES)) // #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) EVAL(_EXEC_DOUBLE_T(RANDOMDOUBLE, NAME, (SIGNATURE), (TYPES_A), TYPES_B)) // #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_TWICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_SINGLE_THRICE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} + // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) switch(XTYPE) { EVAL(_EXEC_SELECTOR_T(SELECTOR_PARTIAL_SINGLE, NAME, SIGNATURE, TYPES)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); }} // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TT_1(SELECTOR_DOUBLE, YTYPE, NAME, (SIGNATURE), (TYPES_B), TYPES_A)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type");}} // #define BUILD_TRIPLE_SELECTOR(XTYPE, YTYPE, ZTYPE, NAME, SIGNATURE, TYPES_X, TYPES_Y, TYPES_Z) switch(XTYPE) { EVAL(_EXEC_SELECTOR_TTT_1(SELECTOR_TRIPLE, YTYPE, ZTYPE, NAME, SIGNATURE, (TYPES_Z), (TYPES_Y), TYPES_X)); default: {printf("[ERROR] Unknown dtypeX=%d on %s:%d", XTYPE, __FILE__, __LINE__); fflush(stdout); throw std::runtime_error("bad data type"); } } @@ -9736,8 +9748,10 @@ public static final int PREALLOC_SIZE = 33554432; // #else // #define BUILD_SINGLE_UNCHAINED_TEMPLATE(NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_TEMPLATE(NAME, SIGNATURE, TYPES) +// #define BUILD_SINGLE_TEMPLATE_TWICE(NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_TEMPLATE(NAME, SIGNATURE, TYPES_A, TYPES_B) // #define BUILD_SINGLE_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) +// #define BUILD_SINGLE_SELECTOR_TWICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_SELECTOR_THRICE(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_SINGLE_PARTIAL_SELECTOR(XTYPE, NAME, SIGNATURE, TYPES) // #define BUILD_DOUBLE_SELECTOR(XTYPE, YTYPE, NAME, SIGNATURE, TYPES_A, TYPES_B) @@ -9773,6 +9787,12 @@ public static final int PREALLOC_SIZE = 33554432; // #define _SELECTOR_SINGLE_THRICE(A, B, C, D) case C: {AB; break;}; // #define SELECTOR_SINGLE_THRICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_THRICE(A, B, UNPAREN(C))) +// #define _SELECTOR_SINGLE_TWICE(A, B, C, D) case C: {AB; break;}; +// #define SELECTOR_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_SINGLE_TWICE(A, B, UNPAREN(C))) + +// #define _TEMPLATE_SINGLE_TWICE(A, B, C, D) AB; +// #define TEMPLATE_SINGLE_TWICE(A, B, C) EVALUATING_PASTE(_TEM, PLATE_SINGLE_TWICE(A, B, UNPAREN(C))) + // #define _SELECTOR_PARTIAL_SINGLE(A, B, C, D) case C: {A D, UNPAREN2(B); break;}; // #define SELECTOR_PARTIAL_SINGLE(A, B, C) EVALUATING_PASTE(_SEL, ECTOR_PARTIAL_SINGLE(A, B, UNPAREN(C))) @@ -9801,6 +9821,7 @@ public static final int PREALLOC_SIZE = 33554432; // #define BROADCAST_BOOL(NAME) nd4j::BroadcastBoolOpsTuple::custom(nd4j::scalar::NAME, nd4j::pairwise::NAME, nd4j::broadcast::NAME) +public static final int ALL_INDICES =INT64; public static final int ALL_INTS =UINT64; public static final int ALL_FLOATS =BFLOAT16; @@ -22694,6 +22715,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 5ad008055..dd47eb25d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -164,6 +164,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) + .put(new Info("OpaqueLaunchContext").pointerTypes("OpaqueLaunchContext")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index d37ddb889..51780fb2c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -7150,7 +7150,11 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getRandom().setSeed(12345); INDArray a = Nd4j.rand(2,5); INDArray b = Nd4j.rand(5,3); - INDArray exp = a.mmul(b).transpose(); + INDArray exp = a.mmul(b); + Nd4j.getExecutioner().commit(); + + exp = exp.transpose(); + INDArray act = a.mmul(b, MMulTranspose.builder().transposeResult(true).build()); assertEquals(exp, act); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index dfcf5dc79..f46d5e694 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -1073,11 +1073,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest { List arrays = new ArrayList<>(); val num = 10; for (int i = 0; i < num; i++) { - arrays.add(Nd4j.create(20, 20).assign(i)); + arrays.add(Nd4j.create(5, 20).assign(i)); } INDArray pile = Nd4j.pile(arrays); + log.info("Pile: {}", pile); + INDArray[] tears = Nd4j.tear(pile, 1, 2); for (int i = 0; i < num; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 8a67bd2c2..d0c61de9b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -444,6 +444,8 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.exec(op); //Should trigger NaN panic fail(); } catch (Exception e){ + //throw new RuntimeException(e); + log.info("Message: {}", e.getMessage()); assertTrue(e.getMessage(), e.getMessage().contains("NaN")); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 0fecaa6fe..596bf16a7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; +import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.INDArrayIndex; @@ -212,7 +213,7 @@ public class ConcatTestsC extends BaseNd4jTest { assertEquals(exp, concat2); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = ND4JIllegalStateException.class) public void testConcatVector() { System.out.println(Nd4j.concat(0, Nd4j.ones(1,1000000), Nd4j.create(1, 1))); }