parent
f05c6ee139
commit
48df1acdfb
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
||||||
import org.nd4j.linalg.factory.NDArrayFactory;
|
import org.nd4j.linalg.factory.NDArrayFactory;
|
||||||
|
@ -482,23 +483,12 @@ public class ConvolutionUtils {
|
||||||
return reshape5dTo2d(format, mask, workspaceMgr, type);
|
return reshape5dTo2d(format, mask, workspaceMgr, type);
|
||||||
} else {
|
} else {
|
||||||
//Need to broadcast first
|
//Need to broadcast first
|
||||||
IntArrayList broadcastDims = new IntArrayList();
|
|
||||||
for(int i=0; i<mask.rank(); i++ ){
|
|
||||||
if(mask.size(i) == label.size(i)){
|
|
||||||
if((format == Convolution3D.DataFormat.NCDHW && i == 1) || (format == Convolution3D.DataFormat.NDHWC && i == 4)){
|
|
||||||
//Skip channels dimension
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
broadcastDims.add(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
long[] lShape = label.shape().clone();
|
long[] lShape = label.shape().clone();
|
||||||
int channelIdx = format == Convolution3D.DataFormat.NCDHW ? 1 : 4;
|
int channelIdx = format == Convolution3D.DataFormat.NCDHW ? 1 : 4;
|
||||||
lShape[channelIdx] = mask.size(channelIdx); //Keep existing channel size
|
lShape[channelIdx] = mask.size(channelIdx); //Keep existing channel size
|
||||||
|
|
||||||
INDArray bMask = workspaceMgr.createUninitialized(type, mask.dataType(), lShape, 'c');
|
INDArray bMask = workspaceMgr.createUninitialized(type, mask.dataType(), lShape, 'c');
|
||||||
int[] bcDims = broadcastDims.toIntArray();
|
Nd4j.exec(new Assign(new INDArray[]{bMask, mask}, new INDArray[]{bMask}));
|
||||||
Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, bcDims));
|
|
||||||
return reshape5dTo2d(format, bMask, workspaceMgr, type);
|
return reshape5dTo2d(format, bMask, workspaceMgr, type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,17 +16,20 @@ endif()
|
||||||
|
|
||||||
# -fsanitize=address
|
# -fsanitize=address
|
||||||
# -fsanitize=leak
|
# -fsanitize=leak
|
||||||
if (APPLE)
|
if (ANDROID_BUILD)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
|
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else")
|
||||||
|
elseif (APPLE)
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true")
|
||||||
elseif(WIN32)
|
elseif(WIN32)
|
||||||
set(X86_BUILD true)
|
set(X86_BUILD true)
|
||||||
if (NOT CUDA_BLAS)
|
if (CUDA_BLAS)
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -DINLINE_LOOPS -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE " /O2 -D_RELEASE=true /wd4804")
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -DINLINE_LOOPS -fmax-errors=2")
|
|
||||||
else()
|
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
|
|
||||||
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
|
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
|
||||||
|
else()
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG " -g -O2 -fPIC -std=c++11 -fmax-errors=2")
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
|
||||||
|
@ -75,6 +78,9 @@ if(NOT CUDA_BLAS)
|
||||||
|
|
||||||
message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
|
message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
|
||||||
add_definitions(-D__EXTERNAL_BLAS__=true)
|
add_definitions(-D__EXTERNAL_BLAS__=true)
|
||||||
|
elseif(WIN32)
|
||||||
|
message("BLAS not found, using downloaded OpenBLAS instead")
|
||||||
|
add_definitions(-D__EXTERNAL_BLAS__=true)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
# if we have externally provided OPENBLAS_PATH - let's use it
|
# if we have externally provided OPENBLAS_PATH - let's use it
|
||||||
|
|
|
@ -5,7 +5,7 @@ project(mkldnn-download NONE)
|
||||||
include(ExternalProject)
|
include(ExternalProject)
|
||||||
ExternalProject_Add(mkldnn
|
ExternalProject_Add(mkldnn
|
||||||
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
|
||||||
GIT_TAG v1.0.2
|
GIT_TAG v1.0.4
|
||||||
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
|
||||||
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
|
|
|
@ -30,8 +30,8 @@ if(APPLE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (APPLE_BUILD)
|
if (APPLE_BUILD)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAPPLE_BUILD=true")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ANDROID_BUILD)
|
if (ANDROID_BUILD)
|
||||||
|
@ -92,11 +92,13 @@ ELSE()
|
||||||
IF(${EXTENSION} MATCHES "avx512")
|
IF(${EXTENSION} MATCHES "avx512")
|
||||||
message("Building AVX512 binary...")
|
message("Building AVX512 binary...")
|
||||||
# we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
|
# we need to set flag here, that we can use hardware f16 conversion + tell that cpu features should be tracked
|
||||||
message("Current CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mmmx -msse -msse2 -msse3 -msse4.1 -msse4.2 -mavx -mavx2 -mfma -mf16c -mavx512f -mavx512vl -mavx512bw -mavx512dq -mavx512cd -mbmi -mbmi2 -mprefetchwt1 -mclflushopt -mxsavec -mxsaves -DSD_F16C=true -DF_AVX512=true")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
|
if (NOT WIN32)
|
||||||
|
# we don't want this definition for msvc
|
||||||
|
set(ARCH_TUNE "-march=${ARCH} -mtune=${ARCH_TYPE}")
|
||||||
|
endif()
|
||||||
ENDIF()
|
ENDIF()
|
||||||
|
|
||||||
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
|
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
|
||||||
|
@ -109,7 +111,7 @@ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel")
|
||||||
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC")
|
||||||
# using Visual Studio C++
|
# using Visual Studio C++
|
||||||
|
|
||||||
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /w")
|
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc ${ARCH_TUNE}")
|
||||||
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
# using GCC
|
# using GCC
|
||||||
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
|
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}")
|
||||||
|
@ -283,8 +285,8 @@ if(CUDA_BLAS)
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
message("CUDA on Windows: enabling /EHsc")
|
message("CUDA on Windows: enabling /EHsc")
|
||||||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj")
|
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
|
||||||
SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj")
|
SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
@ -322,7 +324,7 @@ elseif(CPU_BLAS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (X86_BUILD)
|
if (X86_BUILD)
|
||||||
#we disable platform optimizations for certains files
|
# we disable platform optimizations for certains files for linux/macos
|
||||||
set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
endif()
|
endif()
|
||||||
|
@ -342,7 +344,16 @@ elseif(CPU_BLAS)
|
||||||
add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
add_library(${LIBND4J_NAME} SHARED $<TARGET_OBJECTS:nd4jobj>)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
#if(WIN32)
|
||||||
|
# message("CPU on Windows: enabling /EHsc")
|
||||||
|
# SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /bigobj /std:c++14")
|
||||||
|
# SET_TARGET_PROPERTIES(${LIBND4J_NAME} PROPERTIES COMPILER_FLAGS "/EHsc /bigobj /std:c++14")
|
||||||
|
#endif()
|
||||||
|
|
||||||
# we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
|
# we're including {MKLDNN} here in case of building from sources. in future that'll replace {MKLDNN_LIBRARIES}. same applies to BLAS
|
||||||
|
if (NOT BLAS_LIBRARIES)
|
||||||
|
set(BLAS_LIBRARIES "")
|
||||||
|
endif()
|
||||||
target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
target_link_libraries(${LIBND4J_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
|
||||||
|
|
||||||
if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
|
if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "Environment.h"
|
#include "Environment.h"
|
||||||
#include <helpers/StringUtils.h>
|
#include <helpers/StringUtils.h>
|
||||||
|
#include <thread>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
|
|
||||||
|
@ -49,6 +51,7 @@ namespace nd4j {
|
||||||
_precBoost.store(false);
|
_precBoost.store(false);
|
||||||
_leaks.store(false);
|
_leaks.store(false);
|
||||||
_dataType.store(nd4j::DataType::FLOAT32);
|
_dataType.store(nd4j::DataType::FLOAT32);
|
||||||
|
_maxThreads = std::thread::hardware_concurrency();
|
||||||
|
|
||||||
#ifndef ANDROID
|
#ifndef ANDROID
|
||||||
const char* omp_threads = std::getenv("OMP_NUM_THREADS");
|
const char* omp_threads = std::getenv("OMP_NUM_THREADS");
|
||||||
|
@ -86,9 +89,7 @@ namespace nd4j {
|
||||||
cudaSetDevice(0);
|
cudaSetDevice(0);
|
||||||
delete[] devProperties;
|
delete[] devProperties;
|
||||||
#else
|
#else
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include <indexing/IndicesList.h>
|
#include <indexing/IndicesList.h>
|
||||||
#include <graph/Intervals.h>
|
#include <graph/Intervals.h>
|
||||||
#include <array/DataType.h>
|
#include <array/DataType.h>
|
||||||
|
#include <array/DataTypeUtils.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <array/ArrayOptions.h>
|
#include <array/ArrayOptions.h>
|
||||||
#include <array/ArrayType.h>
|
#include <array/ArrayType.h>
|
||||||
|
@ -1678,7 +1679,6 @@ namespace nd4j {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
size_t NDArray::sizeOfT() const {
|
size_t NDArray::sizeOfT() const {
|
||||||
|
|
||||||
return DataTypeUtils::sizeOfElement(_dataType);
|
return DataTypeUtils::sizeOfElement(_dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2478,7 +2478,6 @@ double NDArray::getTrace() const {
|
||||||
|
|
||||||
double sum = 0.;
|
double sum = 0.;
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
|
||||||
for(int i = 0; i < minDim; ++i)
|
for(int i = 0; i < minDim; ++i)
|
||||||
sum += e<double>(i * offset);
|
sum += e<double>(i * offset);
|
||||||
|
|
||||||
|
@ -3275,7 +3274,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
// regular numeric types
|
// regular numeric types
|
||||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||||
|
|
||||||
ExtraArguments extras({eps});
|
ExtraArguments extras({0.0, 0.0, eps});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
||||||
|
@ -3288,7 +3287,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
|
|
||||||
synchronize("NDArray::equalsTo");
|
synchronize("NDArray::equalsTo");
|
||||||
|
|
||||||
if (tmp.e<int>(0) > 0)
|
if (tmp.e<Nd4jLong>(0) != 0)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -24,10 +24,10 @@
|
||||||
|
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <loops/aggregates.h>
|
|
||||||
#include <ops/specials.h>
|
#include <ops/specials.h>
|
||||||
#include <ops/specials_sparse.h>
|
#include <ops/specials_sparse.h>
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
|
#include <array/ArrayOptions.h>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Native op executioner:
|
* Native op executioner:
|
||||||
|
@ -624,10 +624,6 @@ static void execTransformBool(nd4j::LaunchContext *lc,
|
||||||
void *vrealArguments,
|
void *vrealArguments,
|
||||||
int numRealArguments) {
|
int numRealArguments) {
|
||||||
|
|
||||||
auto arguments = reinterpret_cast<X **>(varguments);
|
|
||||||
auto realArguments = reinterpret_cast<X *>(vrealArguments);
|
|
||||||
|
|
||||||
functions::aggregate::AggregatedFunction<X>::exec(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,6 @@
|
||||||
#define ND4J_EXPORT
|
#define ND4J_EXPORT
|
||||||
#endif
|
#endif
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <helpers/BlasHelper.h>
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
int tad_threshold = 1;
|
int tad_threshold = 1;
|
||||||
|
@ -1430,7 +1429,11 @@ static const char* getNpyArrayNameFromMap(void *map, int index){
|
||||||
for(; it != end; ++it, ++cnt){
|
for(; it != end; ++it, ++cnt){
|
||||||
if (cnt == index){
|
if (cnt == index){
|
||||||
// FIXME: @fariz, this is a leak!
|
// FIXME: @fariz, this is a leak!
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
return const_cast<const char *>(_strdup(it->first.c_str()));
|
||||||
|
#else
|
||||||
return const_cast<const char *>(strdup(it->first.c_str()));
|
return const_cast<const char *>(strdup(it->first.c_str()));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw std::runtime_error("No array at index.");
|
throw std::runtime_error("No array at index.");
|
||||||
|
|
|
@ -98,24 +98,27 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
|
||||||
|
|
||||||
const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo());
|
const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo());
|
||||||
|
|
||||||
std::vector<Nd4jLong> coords(zRank);
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
shape::index2coords(i, target->getShapeInfo(), coords);
|
||||||
|
const auto zOffset = shape::getOffset(target->getShapeInfo(), coords);
|
||||||
|
|
||||||
shape::index2coords(i, target->getShapeInfo(), coords.data());
|
// if( (row + upper < col) || (row + lower > col) )
|
||||||
const auto zOffset = shape::getOffset(target->getShapeInfo(), coords.data());
|
if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
|
||||||
|
z[zOffset] = value;
|
||||||
|
else if (this != target) { // when this and target are different arrays
|
||||||
|
if (xRank != zRank)
|
||||||
|
coords[0] = coords[1];
|
||||||
|
|
||||||
// if( (row + upper < col) || (row + lower > col) )
|
const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords);
|
||||||
if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
|
z[zOffset] = x[xOffset];
|
||||||
z[zOffset] = value;
|
}
|
||||||
else if(this != target) { // when this and target are different arrays
|
|
||||||
if(xRank != zRank)
|
|
||||||
coords[0] = coords[1];
|
|
||||||
const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords.data());
|
|
||||||
z[zOffset] = x[xOffset];
|
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -140,7 +143,7 @@ void NDArray::setIdentity() {
|
||||||
minDim = shape[i];
|
minDim = shape[i];
|
||||||
|
|
||||||
float v = 1.0f;
|
float v = 1.0f;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
|
||||||
for(int i = 0; i < minDim; ++i)
|
for(int i = 0; i < minDim; ++i)
|
||||||
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
|
templatedSet<float>(buffer(), i*offset, this->dataType(), &v);
|
||||||
}
|
}
|
||||||
|
@ -151,12 +154,15 @@ static void templatedSwap(void *xBuffer, void *yBuffer, Nd4jLong length) {
|
||||||
auto x = reinterpret_cast<T *>(xBuffer);
|
auto x = reinterpret_cast<T *>(xBuffer);
|
||||||
auto y = reinterpret_cast<T *>(yBuffer);
|
auto y = reinterpret_cast<T *>(yBuffer);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(static))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < length; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto temp = x[i];
|
auto temp = x[i];
|
||||||
x[i] = y[i];
|
x[i] = y[i];
|
||||||
y[i] = temp;
|
y[i] = temp;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void templatedSwap, (void *xBuffer, void *yBuffer, Nd4jLong length), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -262,21 +268,26 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
auto xType = this->dataType();
|
auto xType = this->dataType();
|
||||||
if(result.ordering() == 'c') { // ews == 1 always here
|
if(result.ordering() == 'c') { // ews == 1 always here
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(Nd4jLong i = 0; i < resultLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), i, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
samediff::Threads::parallel_for(func, 0, resultLen);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(Nd4jLong i=0; i<resultLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto xOffset = result.getOffset(i);
|
auto xOffset = result.getOffset(i);
|
||||||
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
auto yOffset = shape::subArrayOffset(i, newShapeInfo, getShapeInfo());
|
||||||
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign, (result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, this->template templatedAssign,(result.getBuffer(), xOffset, this->getBuffer(), yOffset), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, resultLen);
|
||||||
}
|
}
|
||||||
result.tickWriteHost();
|
result.tickWriteHost();
|
||||||
return result;
|
return result;
|
||||||
|
@ -337,14 +348,7 @@ void NDArray::tile(NDArray& target) const {
|
||||||
// looping through _buffer goes automatically by means of getSubArrayIndex applying
|
// looping through _buffer goes automatically by means of getSubArrayIndex applying
|
||||||
const auto ews = target.ews();
|
const auto ews = target.ews();
|
||||||
const auto targetLen = target.lengthOf();
|
const auto targetLen = target.lengthOf();
|
||||||
if(target.ordering() == 'c' && ews == 1) { // ews == 1 always here
|
if(target.ordering() == 'c' && ews >= 1) {
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < targetLen; ++i) {
|
|
||||||
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
|
|
||||||
BUILD_DOUBLE_SELECTOR(target.dataType(), dataType(), templatedDoubleAssign, (target.getBuffer(), i, getBuffer(), yOffset), LIBND4J_TYPES, LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if(target.ordering() == 'c' && ews > 1) {
|
|
||||||
|
|
||||||
for(Nd4jLong i=0; i<targetLen; ++i) {
|
for(Nd4jLong i=0; i<targetLen; ++i) {
|
||||||
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
|
auto yOffset = shape::subArrayOffset(i, target.getShapeInfo(), getShapeInfo());
|
||||||
|
@ -373,30 +377,30 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
|
||||||
const int zLen = output.lengthOf(); // xLen <= zLen
|
const int zLen = output.lengthOf(); // xLen <= zLen
|
||||||
const int repSize = repeats.size();
|
const int repSize = repeats.size();
|
||||||
|
|
||||||
std::vector<Nd4jLong> coords(rank);
|
|
||||||
|
|
||||||
// loop through input array
|
// loop through input array
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
shape::index2coords(i, output.getShapeInfo(), coords);
|
||||||
|
|
||||||
shape::index2coords(i, output.getShapeInfo(), coords.data());
|
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords);
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
|
if (repSize > 1) {
|
||||||
|
for (uint j = 0; j < repSize; ++j) {
|
||||||
if(repSize > 1) {
|
coords[axis] -= repeats[j];
|
||||||
for (uint j = 0; j < repSize; ++j) {
|
if (coords[axis] < 0) {
|
||||||
coords[axis] -= repeats[j];
|
coords[axis] = j;
|
||||||
if (coords[axis] < 0) {
|
break;
|
||||||
coords[axis] = j;
|
}
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
} else
|
||||||
}
|
coords[axis] /= repeats[0];
|
||||||
else
|
|
||||||
coords[axis] /= repeats[0];
|
|
||||||
|
|
||||||
z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
|
z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords)];
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -32,33 +32,40 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
|
||||||
|
|
||||||
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
z[e] = func(f[e], s[e], t[e]);
|
z[e] = func(f[e], s[e], t[e]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto tOffset = this->getOffset(e);
|
||||||
|
auto uOffset = second->getOffset(e);
|
||||||
|
auto vOffset = third->getOffset(e);
|
||||||
|
|
||||||
auto tOffset = this->getOffset(e);
|
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||||
auto uOffset = second->getOffset(e);
|
}
|
||||||
auto vOffset = third->getOffset(e);
|
};
|
||||||
|
|
||||||
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto tOffset = this->getOffset(e);
|
||||||
|
auto uOffset = second->getOffset(e);
|
||||||
|
auto vOffset = third->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
auto tOffset = this->getOffset(e);
|
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||||
auto uOffset = second->getOffset(e);
|
}
|
||||||
auto vOffset = third->getOffset(e);
|
};
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -103,31 +110,38 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
|
||||||
|
|
||||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
z[e] = func(f[e], s[e]);
|
z[e] = func(f[e], s[e]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
f[xOffset] = func(f[xOffset], s[yOffset]);
|
||||||
auto yOffset = other->getOffset(e);
|
}
|
||||||
|
};
|
||||||
|
|
||||||
f[xOffset] = func(f[xOffset], s[yOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
z[zOffset] = func(f[xOffset], s[yOffset]);
|
||||||
auto yOffset = other->getOffset(e);
|
}
|
||||||
auto zOffset = target->getOffset(e);
|
};
|
||||||
|
|
||||||
z[zOffset] = func(f[xOffset], s[yOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -161,29 +175,36 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
|
||||||
|
|
||||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (int e = 0; e < _length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
z[e] = func(f[e]);
|
z[e] = func(f[e]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (int e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
f[xOffset] = func(f[xOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
f[xOffset] = func(f[xOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (int e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
z[zOffset] = func(f[xOffset]);
|
||||||
auto zOffset = target->getOffset(e);
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[zOffset] = func(f[xOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -217,29 +238,36 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
|
||||||
|
|
||||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
z[e] = func(e, f[e]);
|
z[e] = func(e, f[e]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
f[xOffset] = func(e, f[xOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
f[xOffset] = func(e, f[xOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
z[zOffset] = func(e, f[xOffset]);
|
||||||
auto zOffset = target->getOffset(e);
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[zOffset] = func(e, f[xOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -282,31 +310,38 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
|
||||||
|
|
||||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
} else {
|
} else {
|
||||||
if (f == z) {
|
if (f == z) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (int e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||||
auto yOffset = other->getOffset(e);
|
}
|
||||||
|
};
|
||||||
|
|
||||||
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto loop = PRAGMA_THREADS_FOR {
|
||||||
for (int e = 0; e < _length; e++) {
|
for (auto e = start; e < stop; e += increment) {
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||||
auto yOffset = other->getOffset(e);
|
}
|
||||||
auto zOffset = target->getOffset(e);
|
};
|
||||||
|
|
||||||
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
samediff::Threads::parallel_for(loop, 0, _length);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
#include "NativeOpExecutioner.h"
|
#include "NativeOpExecutioner.h"
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
|
|
||||||
|
#include <LoopKind.h>
|
||||||
|
|
||||||
#include <pairwise_bool.h>
|
#include <pairwise_bool.h>
|
||||||
#include <broadcasting_bool.h>
|
#include <broadcasting_bool.h>
|
||||||
#include <scalar_bool.h>
|
#include <scalar_bool.h>
|
||||||
|
@ -50,11 +52,14 @@
|
||||||
#include <loops/random.h>
|
#include <loops/random.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
#include <exceptions/datatype_exception.h>
|
#include <exceptions/datatype_exception.h>
|
||||||
|
#include <array/TadPack.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
|
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -78,9 +83,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -111,9 +114,7 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -149,9 +150,7 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -160,7 +159,16 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = xLen / yLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,9 +187,7 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
||||||
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
|
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
|
||||||
|
@ -190,7 +196,15 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -208,15 +222,21 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = xLen / yLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
|
@ -231,9 +251,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -243,7 +261,15 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
if (yType != xType || nd4j::DataType::BOOL != zType)
|
if (yType != xType || nd4j::DataType::BOOL != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -260,9 +286,7 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -274,7 +298,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
if (!nd4j::DataTypeUtils::isZ(zType))
|
if (!nd4j::DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = xLen / yLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
|
@ -289,21 +321,27 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
|
||||||
|
|
||||||
if (!nd4j::DataTypeUtils::isZ(zType))
|
if (!nd4j::DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt,::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
auto yLen = shape::length(hYShapeInfo);
|
||||||
|
auto numTads = yLen / xLen;
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numTads);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -328,9 +366,7 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -339,7 +375,15 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform,
|
||||||
|
::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop),
|
||||||
|
LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -353,9 +397,7 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -367,7 +409,13 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
|
||||||
if (zType != nd4j::DataType::BOOL)
|
if (zType != nd4j::DataType::BOOL)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -380,9 +428,7 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
|
@ -394,7 +440,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
|
||||||
if (!nd4j::DataTypeUtils::isZ(zType))
|
if (!nd4j::DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, start, stop), INTEGER_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -417,14 +469,22 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
// nothing to do here if result is empty
|
||||||
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -437,14 +497,22 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
// nothing to do here if result is empty
|
||||||
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -457,14 +525,22 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
|
// nothing to do here if result is empty
|
||||||
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -477,14 +553,22 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
|
// nothing to do here if result is empty
|
||||||
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == nd4j::LoopKind::Kind::SMALLARR2DX ? 1 : nd4j::Environment::getInstance()->maxThreads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -503,9 +587,7 @@ void NativeOpExecutioner::execReduceFloatScalar(nd4j::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -521,9 +603,7 @@ void NativeOpExecutioner::execReduceSameScalar(nd4j::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
|
|
||||||
|
@ -539,9 +619,7 @@ void NativeOpExecutioner::execReduceBoolScalar(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -557,9 +635,7 @@ void NativeOpExecutioner::execReduceLongScalar(nd4j::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -591,10 +667,6 @@ void NativeOpExecutioner::execReduce3Scalar(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
@ -623,15 +695,13 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo) {
|
void *dZ, Nd4jLong *dZShapeInfo) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 1), LIBND4J_TYPES, FLOAT_TYPES);
|
//BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, nullptr, 0), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
NativeOpExecutioner::execReduce3Scalar(lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -647,14 +717,31 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
|
Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength), LIBND4J_TYPES, FLOAT_TYPES);
|
const auto xLen = shape::length(hXShapeInfo);
|
||||||
|
const auto yLen = shape::length(hYShapeInfo);
|
||||||
|
|
||||||
|
nd4j::TadPack tadPack;
|
||||||
|
|
||||||
|
if(xLen == yLen) {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
else if(yLen > xLen) {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -671,15 +758,19 @@ void NativeOpExecutioner::execReduce3All(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
||||||
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||||
// BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
|
// TODO: make it 2d
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -696,15 +787,31 @@ void NativeOpExecutioner::execReduce3TAD(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {
|
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
const auto xLen = shape::length(hXShapeInfo);
|
||||||
// BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
const auto yLen = shape::length(hYShapeInfo);
|
||||||
|
|
||||||
|
nd4j::TadPack tadPack;
|
||||||
|
|
||||||
|
if(xLen == yLen) {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
else if(yLen > xLen) {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hYShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(opNum, hX, hXShapeInfo, extraParamsVals, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, tadPack.numberOfTads());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -729,9 +836,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
void *hScalar, Nd4jLong *hScalarShapeInfo,
|
void *hScalar, Nd4jLong *hScalarShapeInfo,
|
||||||
void *dScalar, Nd4jLong *dScalarShapeInfo,
|
void *dScalar, Nd4jLong *dScalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
|
@ -743,7 +848,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, allowParallelism), LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform,::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -760,9 +871,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
|
@ -774,7 +883,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
|
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -789,9 +904,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
||||||
|
@ -803,7 +916,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
if (zType != nd4j::DataType::BOOL)
|
if (zType != nd4j::DataType::BOOL)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -819,9 +938,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
|
@ -833,7 +950,12 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
if (zType != nd4j::DataType::BOOL)
|
if (zType != nd4j::DataType::BOOL)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", nd4j::DataType::BOOL, zType);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
|
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -847,9 +969,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
||||||
|
@ -861,7 +981,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
if (!nd4j::DataTypeUtils::isZ(zType))
|
if (!nd4j::DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, start, stop), INTEGER_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto zLen = shape::length(hZShapeInfo);
|
||||||
|
samediff::Threads::parallel_for(func, 0, zLen, 1, !allowParallelism ? 1 : nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(zLen / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -877,9 +1003,7 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
|
@ -891,7 +1015,12 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
if (!nd4j::DataTypeUtils::isZ(zType))
|
if (!nd4j::DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ, start, stop), INTEGER_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
auto yLen = shape::length(hScalarShapeInfo);
|
||||||
|
samediff::Threads::parallel_tad(func, 0, yLen, 1, nd4j::math::nd4j_min<int>(yLen, nd4j::Environment::getInstance()->maxThreads()));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -912,9 +1041,7 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -940,9 +1067,7 @@ void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -972,10 +1097,6 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
@ -1002,14 +1123,14 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
auto func = PRAGMA_THREADS_DO {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1021,14 +1142,14 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
|
auto func = PRAGMA_THREADS_DO {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1040,14 +1161,14 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), LIBND4J_TYPES, LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_DO {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1059,14 +1180,14 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
auto func = PRAGMA_THREADS_DO {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1078,14 +1199,14 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), FLOAT_TYPES);
|
auto func = PRAGMA_THREADS_DO {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func, nd4j::math::nd4j_max<int>(1, nd4j::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, nd4j::Environment::getInstance()->maxThreads())));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1095,9 +1216,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
@ -1116,9 +1235,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
@ -1139,9 +1256,7 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
#ifdef _OPENMP
|
|
||||||
omp_set_nested(1);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,6 @@
|
||||||
#include <templatemath.h>
|
#include <templatemath.h>
|
||||||
#include <types/float8.h>
|
#include <types/float8.h>
|
||||||
#include <loops/type_conversions.h>
|
#include <loops/type_conversions.h>
|
||||||
#include <loops/aggregates.h>
|
|
||||||
#include <helpers/helper_ptrmap.h>
|
#include <helpers/helper_ptrmap.h>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
|
@ -36,6 +35,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <ops/declarable/helpers/transforms.h>
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
#include <exceptions/allocation_exception.h>
|
#include <exceptions/allocation_exception.h>
|
||||||
|
#include <helpers/BlasHelper.h>
|
||||||
|
|
||||||
|
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
|
@ -75,6 +75,7 @@ bool experimentalSupport = false;
|
||||||
#include <performance/benchmarking/BenchmarkSuit.h>
|
#include <performance/benchmarking/BenchmarkSuit.h>
|
||||||
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
||||||
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
#ifdef CPU_FEATURES
|
#ifdef CPU_FEATURES
|
||||||
#include <cpuinfo_x86.h>
|
#include <cpuinfo_x86.h>
|
||||||
|
@ -1152,10 +1153,7 @@ void initializeFunctions(Nd4jPointer *functions) {
|
||||||
* @param flags optional parameter
|
* @param flags optional parameter
|
||||||
*/
|
*/
|
||||||
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
||||||
Nd4jPointer pointer = (Nd4jPointer) malloc(memorySize);
|
return reinterpret_cast<Nd4jPointer>(new int8_t[memorySize]);
|
||||||
if (pointer == 0)
|
|
||||||
return 0L;
|
|
||||||
return pointer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1179,7 +1177,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
||||||
* @param pointer pointer that'll be freed
|
* @param pointer pointer that'll be freed
|
||||||
*/
|
*/
|
||||||
int freeHost(Nd4jPointer pointer) {
|
int freeHost(Nd4jPointer pointer) {
|
||||||
free(reinterpret_cast<void *>(pointer));
|
delete[] reinterpret_cast<int8_t *>(pointer);
|
||||||
return 1L;
|
return 1L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1364,37 +1362,37 @@ void pullRowsGeneric(void *vx,
|
||||||
|
|
||||||
int elementsPerThread = n / TAD_THRESHOLD;
|
int elementsPerThread = n / TAD_THRESHOLD;
|
||||||
int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
|
int _threads = nd4j::math::nd4j_max<int>(1, elementsPerThread);
|
||||||
_threads = nd4j::math::nd4j_min<int>(_threads, omp_get_max_threads());
|
_threads = nd4j::math::nd4j_min<int>(_threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(_threads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (int idx = 0; idx < n; idx++) {
|
for (auto idx = start; idx < stop; idx += increment) {
|
||||||
auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
|
auto xTadOffsetForBlock = tadOffsets[indexes[idx]];
|
||||||
auto zTadOffsetForBlock = zTadOffsets[idx];
|
auto zTadOffsetForBlock = zTadOffsets[idx];
|
||||||
|
|
||||||
auto rX = hX + xTadOffsetForBlock;
|
auto rX = hX + xTadOffsetForBlock;
|
||||||
auto rZ = hZ + zTadOffsetForBlock;
|
auto rZ = hZ + zTadOffsetForBlock;
|
||||||
|
|
||||||
if (xEWS == 1 && zEWS == 1) {
|
if (xEWS == 1 && zEWS == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_SIMD
|
for (int i = 0; i < tadLength; i++) {
|
||||||
for (int i = 0; i < tadLength; i++ ) {
|
rZ[i] = rX[i];
|
||||||
rZ[i] = rX[i];
|
}
|
||||||
}
|
} else if (xEWS >= 1 && zEWS >= 1) {
|
||||||
} else if (xEWS >= 1 && zEWS >= 1) {
|
PRAGMA_OMP_SIMD
|
||||||
|
for (int i = 0; i < tadLength; i++) {
|
||||||
PRAGMA_OMP_SIMD
|
rZ[i * zEWS] = rX[i * xEWS];
|
||||||
for (int i = 0; i < tadLength; i++ ) {
|
}
|
||||||
rZ[i * zEWS] = rX[i * xEWS];
|
} else {
|
||||||
|
for (int i = 0; i < tadLength; i++) {
|
||||||
|
auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo);
|
||||||
|
auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
|
||||||
|
hZ[zOffset] = hX[xOffset];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
};
|
||||||
for (int i = 0; i < tadLength; i++) {
|
|
||||||
auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo);
|
samediff::Threads::parallel_tad(func, 0, n, 1, _threads);
|
||||||
auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
|
|
||||||
hZ[zOffset] = hX[xOffset];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void pullRows(Nd4jPointer *extraPointers,
|
void pullRows(Nd4jPointer *extraPointers,
|
||||||
|
@ -1433,30 +1431,29 @@ void tearGeneric(void *vx,
|
||||||
auto zEWS = shape::elementWiseStride(hZShapeInfo);
|
auto zEWS = shape::elementWiseStride(hZShapeInfo);
|
||||||
auto numTads = shape::length(hXShapeInfo) / tadLength;
|
auto numTads = shape::length(hXShapeInfo) / tadLength;
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < numTads; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto hZ = reinterpret_cast<T *>(targets[i]);
|
auto hZ = reinterpret_cast<T *>(targets[i]);
|
||||||
auto s = hX + tadOffsets[i];
|
auto s = hX + tadOffsets[i];
|
||||||
|
|
||||||
if (zEWS == 1 && tadEWS == 1) {
|
if (zEWS == 1 && tadEWS == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_SIMD
|
for (Nd4jLong j = 0; j < tadLength; j++) {
|
||||||
for (Nd4jLong j = 0; j < tadLength; j++) {
|
hZ[j] = s[j];
|
||||||
hZ[j] = s[j];
|
}
|
||||||
}
|
} else if (zEWS > 0 && tadEWS > 0) {
|
||||||
} else if (zEWS > 0 && tadEWS > 0) {
|
PRAGMA_OMP_SIMD
|
||||||
|
for (Nd4jLong j = 0; j < tadLength; j++) {
|
||||||
PRAGMA_OMP_SIMD
|
hZ[j * zEWS] = s[j * tadEWS];
|
||||||
for (Nd4jLong j = 0; j < tadLength; j++) {
|
}
|
||||||
hZ[j * zEWS] = s[j * tadEWS];
|
} else {
|
||||||
|
for (Nd4jLong j = 0; j < tadLength; j++)
|
||||||
|
hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
};
|
||||||
|
|
||||||
for (Nd4jLong j = 0; j < tadLength; j++)
|
samediff::Threads::parallel_tad(func,0, numTads);
|
||||||
hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void tear(Nd4jPointer *extraPointers,
|
void tear(Nd4jPointer *extraPointers,
|
||||||
|
@ -1557,57 +1554,60 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
|
||||||
auto dX = reinterpret_cast<T **>(hX);
|
auto dX = reinterpret_cast<T **>(hX);
|
||||||
auto dZ = reinterpret_cast<T **>(dz);
|
auto dZ = reinterpret_cast<T **>(dz);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(N)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (int f = 0; f < N; f++) {
|
for (auto f = start; f < stop; f += increment) {
|
||||||
auto hX = reinterpret_cast<T *>(dX[f]);
|
auto hX = reinterpret_cast<T *>(dX[f]);
|
||||||
//auto hZ = reinterpret_cast<T *>(dZ[f]);
|
//auto hZ = reinterpret_cast<T *>(dZ[f]);
|
||||||
|
|
||||||
auto xShapeInfo = hXShapeInfo[f];
|
auto xShapeInfo = hXShapeInfo[f];
|
||||||
auto tadOffset = reinterpret_cast<Nd4jLong *>(tadOffsets[f]);
|
auto tadOffset = reinterpret_cast<Nd4jLong *>(tadOffsets[f]);
|
||||||
|
|
||||||
|
|
||||||
const auto tadLength = shape::length(tadOnlyShapeInfo[f]);
|
const auto tadLength = shape::length(tadOnlyShapeInfo[f]);
|
||||||
auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]);
|
auto tadEWS = shape::elementWiseStride(tadOnlyShapeInfo[f]);
|
||||||
auto tadRank = shape::rank(tadOnlyShapeInfo[f]);
|
auto tadRank = shape::rank(tadOnlyShapeInfo[f]);
|
||||||
auto numTads = shape::length(hXShapeInfo[f]) / tadLength;
|
auto numTads = shape::length(hXShapeInfo[f]) / tadLength;
|
||||||
|
|
||||||
auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]);
|
auto tadShape = shape::shapeOf(tadOnlyShapeInfo[f]);
|
||||||
auto tadStride = shape::stride(tadOnlyShapeInfo[f]);
|
auto tadStride = shape::stride(tadOnlyShapeInfo[f]);
|
||||||
|
|
||||||
if (shape::rank(xShapeInfo) == 1) {
|
if (shape::rank(xShapeInfo) == 1) {
|
||||||
auto xLength = shape::length(xShapeInfo);
|
auto xLength = shape::length(xShapeInfo);
|
||||||
auto ews = shape::elementWiseStride(xShapeInfo);
|
auto ews = shape::elementWiseStride(xShapeInfo);
|
||||||
for (Nd4jLong r = 0; r < xLength; r++) {
|
for (Nd4jLong r = 0; r < xLength; r++) {
|
||||||
auto swapIdx = shuffleMap[r];
|
auto swapIdx = shuffleMap[r];
|
||||||
if (swapIdx < 0)
|
if (swapIdx < 0)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
nd4j::math::nd4j_swap<T>(hX[r*ews], hX[swapIdx*ews]);
|
nd4j::math::nd4j_swap<T>(hX[r * ews], hX[swapIdx * ews]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (Nd4jLong r = 0; r < numTads; r++) {
|
for (Nd4jLong r = 0; r < numTads; r++) {
|
||||||
if (shuffleMap[r] < 0)
|
if (shuffleMap[r] < 0)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto oldOffset = tadOffset[r];
|
auto oldOffset = tadOffset[r];
|
||||||
auto newOffset = tadOffset[shuffleMap[r]];
|
auto newOffset = tadOffset[shuffleMap[r]];
|
||||||
|
|
||||||
auto rX = hX + oldOffset;
|
auto rX = hX + oldOffset;
|
||||||
auto rY = hX + newOffset;
|
auto rY = hX + newOffset;
|
||||||
|
|
||||||
if (tadEWS == 1) {
|
if (tadEWS == 1) {
|
||||||
for (Nd4jLong i = 0; i < tadLength; i++) {
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
||||||
nd4j::math::nd4j_swap<T>(rX[i], rY[i]);
|
nd4j::math::nd4j_swap<T>(rX[i], rY[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (Nd4jLong i = 0; i < tadLength; i++) {
|
for (Nd4jLong i = 0; i < tadLength; i++) {
|
||||||
auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
|
auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
|
||||||
nd4j::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]);
|
nd4j::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
void shuffle(Nd4jPointer *extras,
|
void shuffle(Nd4jPointer *extras,
|
||||||
|
@ -1772,72 +1772,9 @@ void execAggregate(Nd4jPointer *extraPointers,int opNum,
|
||||||
void *realArguments,
|
void *realArguments,
|
||||||
int numRealArguments,
|
int numRealArguments,
|
||||||
nd4j::DataType dtype) {
|
nd4j::DataType dtype) {
|
||||||
try {
|
|
||||||
BUILD_SINGLE_SELECTOR(dtype, NativeOpExecutioner::execAggregate, (nullptr, opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), FLOAT_TYPES);
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void _batchExecutor(Nd4jPointer *extraPointers,
|
|
||||||
int numAggregates,
|
|
||||||
int opNum,
|
|
||||||
int maxArgs,
|
|
||||||
int maxShapes,
|
|
||||||
int maxIntArrays,
|
|
||||||
int maxIntArraySize,
|
|
||||||
int maxIdx,
|
|
||||||
int maxReals,
|
|
||||||
void *ptrToArguments,
|
|
||||||
nd4j::DataType dtype) {
|
|
||||||
// probably, we don't want too much threads as usually
|
|
||||||
int _threads = nd4j::math::nd4j_min<int>(numAggregates, omp_get_max_threads());
|
|
||||||
|
|
||||||
nd4j::PointersHelper<T> helper(ptrToArguments,
|
|
||||||
numAggregates,
|
|
||||||
maxArgs,
|
|
||||||
maxShapes,
|
|
||||||
maxIntArrays,
|
|
||||||
maxIntArraySize,
|
|
||||||
maxIdx,
|
|
||||||
maxReals);
|
|
||||||
|
|
||||||
// special case here, we prefer spread arrangement here, all threads are detached from each other
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(_threads)
|
|
||||||
for (int i = 0; i < numAggregates; i++) {
|
|
||||||
auto intArrays = new int *[maxIntArrays];
|
|
||||||
|
|
||||||
auto arguments = helper.getArguments(i);
|
|
||||||
auto shapes = helper.getShapeArguments(i);
|
|
||||||
auto idxArg = helper.getIndexArguments(i);
|
|
||||||
auto realArg = helper.getRealArguments(i);
|
|
||||||
|
|
||||||
for (int e = 0; e < maxIntArrays; e++) {
|
|
||||||
intArrays[e] = helper.getIntArrayArguments(i, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
execAggregate(extraPointers,
|
|
||||||
opNum,
|
|
||||||
reinterpret_cast<void **>(arguments),
|
|
||||||
helper.getNumArguments(i),
|
|
||||||
shapes,
|
|
||||||
helper.getNumShapeArguments(i),
|
|
||||||
idxArg,
|
|
||||||
helper.getNumIndexArguments(i),
|
|
||||||
intArrays,
|
|
||||||
helper.getNumIntArrayArguments(i),
|
|
||||||
realArg,
|
|
||||||
helper.getNumRealArguments(i),
|
|
||||||
dtype);
|
|
||||||
|
|
||||||
delete [] intArrays;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _batchExecutor, (Nd4jPointer *extraPointers, int numAggregates, int opNum, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments, nd4j::DataType dtype), FLOAT_TYPES);
|
|
||||||
|
|
||||||
void batchExecutor(Nd4jPointer *extraPointers,
|
void batchExecutor(Nd4jPointer *extraPointers,
|
||||||
int numAggregates,
|
int numAggregates,
|
||||||
int opNum,
|
int opNum,
|
||||||
|
@ -1849,12 +1786,7 @@ void batchExecutor(Nd4jPointer *extraPointers,
|
||||||
int maxReals,
|
int maxReals,
|
||||||
void *ptrToArguments,
|
void *ptrToArguments,
|
||||||
nd4j::DataType dtype) {
|
nd4j::DataType dtype) {
|
||||||
try {
|
|
||||||
BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES);
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void execAggregateBatch(Nd4jPointer *extraPointers,
|
void execAggregateBatch(Nd4jPointer *extraPointers,
|
||||||
|
@ -1868,12 +1800,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
|
||||||
int maxReals,
|
int maxReals,
|
||||||
void *ptrToArguments,
|
void *ptrToArguments,
|
||||||
nd4j::DataType dtype) {
|
nd4j::DataType dtype) {
|
||||||
try {
|
|
||||||
BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES);
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2094,27 +2021,21 @@ const char* getAllCustomOps() {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) {
|
FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer hX, int N, T threshold) {
|
||||||
auto buffer = reinterpret_cast<T *>(hX);
|
auto buffer = reinterpret_cast<T *>(hX);
|
||||||
|
|
||||||
int span = (N / 6) + 8;
|
int span = (N / 6) + 8;
|
||||||
int cnt = 0;
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_REDUCTION(+:cnt)
|
|
||||||
{
|
|
||||||
int tid = omp_get_thread_num();
|
|
||||||
int start = span * tid;
|
|
||||||
int stop = span * (tid + 1);
|
|
||||||
if (stop > N)
|
|
||||||
stop = N;
|
|
||||||
|
|
||||||
|
auto func = PRAGMA_REDUCE_LONG {
|
||||||
|
int64_t cnt = 0;
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int e = start; e < stop; e++) {
|
for (auto e = start; e < stop; e++) {
|
||||||
auto v = nd4j::math::nd4j_abs<T>(buffer[e]);
|
auto v = nd4j::math::nd4j_abs<T>(buffer[e]);
|
||||||
if (v >= threshold)
|
if (v >= threshold)
|
||||||
cnt++;
|
cnt++;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return cnt;
|
return cnt;
|
||||||
|
};
|
||||||
|
|
||||||
|
return samediff::Threads::parallel_long(func, LAMBDA_AL { return _old + _new; }, 0, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2776,58 +2697,51 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
|
||||||
void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
|
void* vIindexes, Nd4jLong* hIndicesShapeInfo, void* dIindexes, Nd4jLong* dIndicesShapeInfo) {
|
||||||
|
|
||||||
auto hIindexes = reinterpret_cast<I*>(vIindexes);
|
auto hIindexes = reinterpret_cast<I*>(vIindexes);
|
||||||
|
auto func = PRAGMA_THREADS_DO {
|
||||||
int numThreads = omp_get_max_threads();
|
for (int i = 0; i < numOfSubArrs; ++i) {
|
||||||
|
int threadIndex = thread_id;
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(numThreads)
|
|
||||||
{
|
|
||||||
for (int i = 0; i < numOfSubArrs; ++i) {
|
|
||||||
|
|
||||||
int threadIndex = omp_get_thread_num();
|
|
||||||
const auto xIndex = hIindexes[i];
|
const auto xIndex = hIindexes[i];
|
||||||
const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads;
|
const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads;
|
||||||
|
|
||||||
if (!isOwner)
|
if (!isOwner)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
NDArray inSubArr(
|
NDArray inSubArr(reinterpret_cast<int8_t *>(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo);
|
||||||
reinterpret_cast<int8_t *>(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)),
|
NDArray updSubArr(reinterpret_cast<int8_t *>(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo);
|
||||||
hXShapeInfo);
|
|
||||||
NDArray updSubArr(reinterpret_cast<int8_t *>(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)),
|
|
||||||
hYShapeInfo);
|
|
||||||
|
|
||||||
if (inSubArr.lengthOf() != updSubArr.lengthOf()) {
|
if (inSubArr.lengthOf() != updSubArr.lengthOf()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (opCode) {
|
switch (opCode) {
|
||||||
case 0:
|
case 0:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 5:
|
case 5:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
case 6:
|
case 6:
|
||||||
inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr);
|
inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
|
||||||
|
|
||||||
|
samediff::Threads::parallel_do(func);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2847,6 +2761,7 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) {
|
void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) {
|
||||||
try {
|
try {
|
||||||
auto p = reinterpret_cast<nd4j::DebugInfo *>(debugInfo);
|
auto p = reinterpret_cast<nd4j::DebugInfo *>(debugInfo);
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <loops/transform_any.h>
|
#include <loops/transform_any.h>
|
||||||
#include <loops/reduce_bool.h>
|
#include <loops/reduce_bool.h>
|
||||||
#include <loops/reduce_long.h>
|
#include <loops/reduce_long.h>
|
||||||
|
#include <loops/scalar.h>
|
||||||
#include <helpers/threshold.h>
|
#include <helpers/threshold.h>
|
||||||
#include <ops/specials_cuda.h>
|
#include <ops/specials_cuda.h>
|
||||||
#include <helpers/DebugHelper.h>
|
#include <helpers/DebugHelper.h>
|
||||||
|
@ -33,8 +34,8 @@
|
||||||
#include <exceptions/datatype_exception.h>
|
#include <exceptions/datatype_exception.h>
|
||||||
#include <exceptions/cuda_exception.h>
|
#include <exceptions/cuda_exception.h>
|
||||||
#include <helpers/CudaLaunchHelper.h>
|
#include <helpers/CudaLaunchHelper.h>
|
||||||
// FIXME: we need cuda-specific implementations
|
|
||||||
#include <GraphExecutioner.h>
|
#include <GraphExecutioner.h>
|
||||||
|
#include <helpers/BlasHelper.h>
|
||||||
#include <graph/GraphHolder.h>
|
#include <graph/GraphHolder.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <PointersManager.h>
|
#include <PointersManager.h>
|
||||||
|
@ -1723,11 +1724,7 @@ void execScalarTad(Nd4jPointer *extraPointers,
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform,
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
||||||
::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ,
|
|
||||||
dZShapeInfo, dScalars, extraParams, dimension,
|
|
||||||
dimensionLength, tadShapeInfo, tadOffsets,
|
|
||||||
tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
DEBUG_KERNEL(stream, opNum);
|
||||||
|
@ -1750,23 +1747,7 @@ void execAggregate(Nd4jPointer *extraPointers,
|
||||||
void *realArguments,
|
void *realArguments,
|
||||||
int numRealArguments,
|
int numRealArguments,
|
||||||
nd4j::DataType dtype) {
|
nd4j::DataType dtype) {
|
||||||
try {
|
|
||||||
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
|
||||||
int numBlocks = getDeviceId(extraPointers[2]);
|
|
||||||
int numThreads = getDeviceId(extraPointers[3]);
|
|
||||||
int shmem = getDeviceId(extraPointers[4]);
|
|
||||||
|
|
||||||
dim3 launchDims = dim3(numBlocks, numThreads, shmem);
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction,
|
|
||||||
::aggregateKernelGeneric(launchDims, stream, opNum, arguments, numArguments, shapes,
|
|
||||||
numShapes, indexArguments, numIndexArguments, intArrays,
|
|
||||||
numIntArrays, realArguments, numRealArguments), FLOAT_TYPES);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed");
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void batchExecutor(Nd4jPointer *extraPointers,
|
void batchExecutor(Nd4jPointer *extraPointers,
|
||||||
|
@ -1788,25 +1769,7 @@ void execAggregateBatch(Nd4jPointer *extraPointers,
|
||||||
int maxIntArrays, int maxIntArraySize,
|
int maxIntArrays, int maxIntArraySize,
|
||||||
int maxIdx, int maxReals,
|
int maxIdx, int maxReals,
|
||||||
void *ptrToArguments, nd4j::DataType dtype) {
|
void *ptrToArguments, nd4j::DataType dtype) {
|
||||||
try {
|
|
||||||
// not implemented yet
|
|
||||||
cudaStream_t *stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
|
||||||
int numBlocks = getDeviceId(extraPointers[2]);
|
|
||||||
int numThreads = getDeviceId(extraPointers[3]);
|
|
||||||
int shmem = getDeviceId(extraPointers[4]);
|
|
||||||
|
|
||||||
dim3 launchDims = dim3(numAggregates, numThreads, shmem);
|
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction,
|
|
||||||
::aggregateBatchKernelGeneric(launchDims, stream, opNum, numAggregates, maxArgs,
|
|
||||||
maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals,
|
|
||||||
ptrToArguments), FLOAT_TYPES);
|
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
|
||||||
} catch (std::exception &e) {
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -53,6 +53,7 @@ CLEAN="false"
|
||||||
MINIFIER="false"
|
MINIFIER="false"
|
||||||
TESTS="false"
|
TESTS="false"
|
||||||
VERBOSE="false"
|
VERBOSE="false"
|
||||||
|
VERBOSE_ARG="VERBOSE=1"
|
||||||
HELPER=
|
HELPER=
|
||||||
NAME=
|
NAME=
|
||||||
while [[ $# > 0 ]]
|
while [[ $# > 0 ]]
|
||||||
|
@ -291,38 +292,37 @@ case "$OS" in
|
||||||
|
|
||||||
macosx*)
|
macosx*)
|
||||||
# Do something under Mac OS X platform
|
# Do something under Mac OS X platform
|
||||||
if [ "$CHIP" == "cuda" ]; then
|
#if [ "$CHIP" == "cuda" ]; then
|
||||||
export CC=clang
|
export CC=clang
|
||||||
export CXX=clang++
|
export CXX=clang++
|
||||||
PARALLEL="false"
|
|
||||||
else
|
|
||||||
export CC="$(ls -1 /usr/local/bin/gcc-? | head -n 1)"
|
|
||||||
export CXX="$(ls -1 /usr/local/bin/g++-? | head -n 1)"
|
|
||||||
PARALLEL="true"
|
PARALLEL="true"
|
||||||
fi
|
#else
|
||||||
|
# export CC="$(ls -1 /usr/local/bin/gcc-? | head -n 1)"
|
||||||
|
# export CXX="$(ls -1 /usr/local/bin/g++-? | head -n 1)"
|
||||||
|
# PARALLEL="true"
|
||||||
|
#fi
|
||||||
export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true"
|
export CMAKE_COMMAND="$CMAKE_COMMAND -DCMAKE_MACOSX_RPATH=ON -DAPPLE_BUILD=true"
|
||||||
;;
|
;;
|
||||||
|
|
||||||
windows*)
|
windows*)
|
||||||
# Do something under Windows NT platform
|
# Do something under Windows NT platform
|
||||||
if [ "$CHIP" == "cuda" ]; then
|
if [ "$CHIP" == "cuda" ]; then
|
||||||
export CMAKE_COMMAND="cmake -G \"Ninja\""
|
export CMAKE_COMMAND="cmake -G \"Ninja\""
|
||||||
export MAKE_COMMAND="ninja"
|
export MAKE_COMMAND="ninja"
|
||||||
export CC="cl.exe"
|
export CC="cl.exe"
|
||||||
export CXX="cl.exe"
|
export CXX="cl.exe"
|
||||||
PARALLEL="true"
|
PARALLEL="true"
|
||||||
else
|
VERBOSE_ARG="-v"
|
||||||
|
else
|
||||||
export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\""
|
export CMAKE_COMMAND="cmake -G \"MSYS Makefiles\""
|
||||||
export MAKE_COMMAND="make"
|
export MAKE_COMMAND="make"
|
||||||
|
|
||||||
# Sam, do we really need this?
|
|
||||||
export CC=/mingw64/bin/gcc
|
export CC=/mingw64/bin/gcc
|
||||||
export CXX=/mingw64/bin/g++
|
export CXX=/mingw64/bin/g++
|
||||||
PARALLEL="true"
|
PARALLEL="true"
|
||||||
|
fi
|
||||||
|
|
||||||
fi
|
# Try some defaults for Visual Studio 2013 if user has not run vcvarsall.bat or something
|
||||||
# Try some defaults for Visual Studio 2013 if user has not run vcvarsall.bat or something
|
if [ -z "${VCINSTALLDIR:-}" ]; then
|
||||||
if [ -z "${VCINSTALLDIR:-}" ]; then
|
|
||||||
export VisualStudioVersion=12.0
|
export VisualStudioVersion=12.0
|
||||||
export VSINSTALLDIR="C:\\Program Files (x86)\\Microsoft Visual Studio $VisualStudioVersion"
|
export VSINSTALLDIR="C:\\Program Files (x86)\\Microsoft Visual Studio $VisualStudioVersion"
|
||||||
export VCINSTALLDIR="$VSINSTALLDIR\\VC"
|
export VCINSTALLDIR="$VSINSTALLDIR\\VC"
|
||||||
|
@ -332,10 +332,10 @@ case "$OS" in
|
||||||
export LIB="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\lib\\winv6.3\\um\\x64"
|
export LIB="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\lib\\winv6.3\\um\\x64"
|
||||||
export LIBPATH="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\References\\CommonConfiguration\\Neutral"
|
export LIBPATH="$VCINSTALLDIR\\LIB\\amd64;$WindowsSdkDir\\References\\CommonConfiguration\\Neutral"
|
||||||
export PATH="$PATH:$VCINSTALLDIR\\BIN\\amd64:$WindowsSdkDir\\bin\\x64:$WindowsSdkDir\\bin\\x86"
|
export PATH="$PATH:$VCINSTALLDIR\\BIN\\amd64:$WindowsSdkDir\\bin\\x64:$WindowsSdkDir\\bin\\x86"
|
||||||
fi
|
fi
|
||||||
# Make sure we are using 64-bit MinGW-w64
|
# Make sure we are using 64-bit MinGW-w64
|
||||||
export PATH=/mingw64/bin/:$PATH
|
export PATH=/mingw64/bin/:/mingw64/lib:$PATH
|
||||||
# export GENERATOR="MSYS Makefiles"
|
# export GENERATOR="MSYS Makefiles"
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
|
|
||||||
|
@ -534,6 +534,6 @@ if [ "$PARALLEL" == "true" ]; then
|
||||||
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
|
MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ"
|
||||||
fi
|
fi
|
||||||
if [ "$VERBOSE" == "true" ]; then
|
if [ "$VERBOSE" == "true" ]; then
|
||||||
MAKE_ARGUMENTS="$MAKE_ARGUMENTS VERBOSE=1"
|
MAKE_ARGUMENTS="$MAKE_ARGUMENTS $VERBOSE_ARG"
|
||||||
fi
|
fi
|
||||||
eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../..
|
eval $MAKE_COMMAND $MAKE_ARGUMENTS && cd ../../..
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include <helpers/BitwiseUtils.h>
|
#include <helpers/BitwiseUtils.h>
|
||||||
#include <loops/type_conversions.h>
|
#include <loops/type_conversions.h>
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -50,9 +51,12 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
|
TypeCast::convertGeneric<T2, T>(nullptr, tmp, length, buffer);
|
||||||
#else
|
#else
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
|
@ -105,9 +109,12 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
|
TypeCast::convertGeneric<float, T>(nullptr, tmp, length, buffer);
|
||||||
#else
|
#else
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
|
@ -130,9 +137,12 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
#else
|
#else
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
}
|
}
|
||||||
|
@ -153,9 +163,12 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
|
TypeCast::convertGeneric<float16, T>(nullptr, tmp, length, buffer);
|
||||||
#else
|
#else
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong e = 0; e < length; e++)
|
for (auto e = start; e < stop; e += increment)
|
||||||
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
buffer[e] = canKeep ? static_cast<T>(tmp[e]) : BitwiseUtils::swap_bytes<T>(static_cast<T>(tmp[e]));
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length);
|
||||||
#endif
|
#endif
|
||||||
delete[] tmp;
|
delete[] tmp;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#include <helpers/DebugHelper.h>
|
||||||
#endif
|
#endif
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
|
|
||||||
|
|
|
@ -97,10 +97,10 @@ namespace cnpy {
|
||||||
* @param t
|
* @param t
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
char mapType(const std::type_info &t);
|
ND4J_EXPORT char mapType(const std::type_info &t);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
char mapType();
|
ND4J_EXPORT char mapType();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -111,7 +111,7 @@ namespace cnpy {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
template<typename T>
|
template<typename T>
|
||||||
std::vector<char> createNpyHeader(const void *data,
|
ND4J_EXPORT std::vector<char> createNpyHeader(const void *data,
|
||||||
const unsigned int *shape,
|
const unsigned int *shape,
|
||||||
const unsigned int ndims,
|
const unsigned int ndims,
|
||||||
unsigned int wordSize = 4);
|
unsigned int wordSize = 4);
|
||||||
|
@ -126,7 +126,7 @@ namespace cnpy {
|
||||||
* @param ndims
|
* @param ndims
|
||||||
* @param fortranOrder
|
* @param fortranOrder
|
||||||
*/
|
*/
|
||||||
void parseNpyHeader(FILE *fp,
|
ND4J_EXPORT void parseNpyHeader(FILE *fp,
|
||||||
unsigned int &wordSize,
|
unsigned int &wordSize,
|
||||||
unsigned int *&shape,
|
unsigned int *&shape,
|
||||||
unsigned int &ndims,
|
unsigned int &ndims,
|
||||||
|
@ -143,7 +143,7 @@ namespace cnpy {
|
||||||
* @param ndims
|
* @param ndims
|
||||||
* @param fortran_order
|
* @param fortran_order
|
||||||
*/
|
*/
|
||||||
void parseNpyHeaderPointer(
|
ND4J_EXPORT void parseNpyHeaderPointer(
|
||||||
const char *header,
|
const char *header,
|
||||||
unsigned int& word_size,
|
unsigned int& word_size,
|
||||||
unsigned int*& shape,
|
unsigned int*& shape,
|
||||||
|
@ -156,7 +156,7 @@ namespace cnpy {
|
||||||
* @param global_header_size
|
* @param global_header_size
|
||||||
* @param global_header_offset
|
* @param global_header_offset
|
||||||
*/
|
*/
|
||||||
void parseZipFooter(FILE *fp,
|
ND4J_EXPORT void parseZipFooter(FILE *fp,
|
||||||
unsigned short &nrecs,
|
unsigned short &nrecs,
|
||||||
unsigned int &global_header_size,
|
unsigned int &global_header_size,
|
||||||
unsigned int &global_header_offset);
|
unsigned int &global_header_offset);
|
||||||
|
@ -167,14 +167,14 @@ namespace cnpy {
|
||||||
* @param varname
|
* @param varname
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
NpyArray npzLoad(std::string fname, std::string varname);
|
ND4J_EXPORT NpyArray npzLoad(std::string fname, std::string varname);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param fname
|
* @param fname
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
NpyArray npyLoad(std::string fname);
|
ND4J_EXPORT NpyArray npyLoad(std::string fname);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parse the numpy header from
|
* Parse the numpy header from
|
||||||
|
@ -187,7 +187,7 @@ namespace cnpy {
|
||||||
* @param ndims
|
* @param ndims
|
||||||
* @param fortranOrder
|
* @param fortranOrder
|
||||||
*/
|
*/
|
||||||
void parseNpyHeaderStr(std::string header,
|
ND4J_EXPORT void parseNpyHeaderStr(std::string header,
|
||||||
unsigned int &wordSize,
|
unsigned int &wordSize,
|
||||||
unsigned int *&shape,
|
unsigned int *&shape,
|
||||||
unsigned int &ndims,
|
unsigned int &ndims,
|
||||||
|
@ -199,14 +199,14 @@ namespace cnpy {
|
||||||
* @param fp
|
* @param fp
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
int * shapeFromFile(FILE *fp);
|
ND4J_EXPORT int* shapeFromFile(FILE *fp);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param data
|
* @param data
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
int * shapeFromPointer(char *data);
|
ND4J_EXPORT int* shapeFromPointer(char *data);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the numpy array from the given file.
|
* Load the numpy array from the given file.
|
||||||
|
@ -250,7 +250,7 @@ namespace cnpy {
|
||||||
* @param ndims
|
* @param ndims
|
||||||
* @param fortran_order
|
* @param fortran_order
|
||||||
*/
|
*/
|
||||||
void parseNpyHeader(std::string header,
|
ND4J_EXPORT void parseNpyHeader(std::string header,
|
||||||
unsigned int &word_size,
|
unsigned int &word_size,
|
||||||
unsigned int *&shape,
|
unsigned int *&shape,
|
||||||
unsigned int &ndims,
|
unsigned int &ndims,
|
||||||
|
@ -273,7 +273,7 @@ namespace cnpy {
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w");
|
ND4J_EXPORT void npy_save(std::string fname, const T* data, const unsigned int* shape, const unsigned int ndims, std::string mode = "w");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,8 +284,8 @@ namespace cnpy {
|
||||||
* @param rhs
|
* @param rhs
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
template<typename T>
|
template<typename T>
|
||||||
std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs);
|
ND4J_EXPORT std::vector<char>& operator+=(std::vector<char>& lhs, const T rhs);
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
#ifndef NATIVEOPERATIONS_DLL_H
|
#ifndef NATIVEOPERATIONS_DLL_H
|
||||||
#define NATIVEOPERATIONS_DLL_H
|
#define NATIVEOPERATIONS_DLL_H
|
||||||
|
|
||||||
|
#include <msvc.h>
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
//#include <windows.h>
|
//#include <windows.h>
|
||||||
# define ND4J_EXPORT __declspec(dllexport)
|
# define ND4J_EXPORT __declspec(dllexport)
|
||||||
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_BLOCKINGQUEUE_H
|
||||||
|
#define SAMEDIFF_BLOCKINGQUEUE_H
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <queue>
|
||||||
|
#include <mutex>
|
||||||
|
#include <atomic>
|
||||||
|
#include <condition_variable>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
template <typename T>
|
||||||
|
class BlockingQueue {
|
||||||
|
private:
|
||||||
|
std::queue<T> _queue;
|
||||||
|
std::mutex _lock;
|
||||||
|
std::atomic<int> _size;
|
||||||
|
std::atomic<bool> _available;
|
||||||
|
|
||||||
|
std::condition_variable _condition;
|
||||||
|
public:
|
||||||
|
BlockingQueue(int queueSize);
|
||||||
|
~BlockingQueue() = default;
|
||||||
|
T poll();
|
||||||
|
void put(const T &t);
|
||||||
|
|
||||||
|
bool available();
|
||||||
|
void markAvailable();
|
||||||
|
void markUnavailable();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_BLOCKINGQUEUE_H
|
|
@ -0,0 +1,94 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_CALLABLEINTERFACE_H
|
||||||
|
#define SAMEDIFF_CALLABLEINTERFACE_H
|
||||||
|
|
||||||
|
#include <openmp_pragmas.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <atomic>
|
||||||
|
#include <array>
|
||||||
|
#include <mutex>
|
||||||
|
#include <condition_variable>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
/**
|
||||||
|
* This class is suited for passing functions to execution threads without queues
|
||||||
|
*/
|
||||||
|
class CallableInterface {
|
||||||
|
private:
|
||||||
|
// parallel_for functions
|
||||||
|
FUNC_1D _function_1d;
|
||||||
|
FUNC_2D _function_2d;
|
||||||
|
FUNC_3D _function_3d;
|
||||||
|
|
||||||
|
// parallel function
|
||||||
|
FUNC_DO _function_do;
|
||||||
|
|
||||||
|
// reduction functions
|
||||||
|
FUNC_RL _function_rl;
|
||||||
|
FUNC_RD _function_rd;
|
||||||
|
|
||||||
|
std::array<int64_t, 9> _arguments;
|
||||||
|
|
||||||
|
volatile int _branch = 0;
|
||||||
|
volatile uint32_t _thread_id = 0;
|
||||||
|
volatile uint32_t _num_threads = 0;
|
||||||
|
|
||||||
|
std::atomic<bool> _finished;
|
||||||
|
std::atomic<bool> _filled;
|
||||||
|
std::atomic<bool> _available;
|
||||||
|
|
||||||
|
std::condition_variable _starter;
|
||||||
|
std::condition_variable _finisher;
|
||||||
|
|
||||||
|
int64_t* _lptr = nullptr;
|
||||||
|
double* _dptr = nullptr;
|
||||||
|
|
||||||
|
std::mutex _ms;
|
||||||
|
std::mutex _mf;
|
||||||
|
public:
|
||||||
|
CallableInterface();
|
||||||
|
~CallableInterface() = default;
|
||||||
|
|
||||||
|
void waitForTask();
|
||||||
|
void waitForCompletion();
|
||||||
|
|
||||||
|
void fill(int thread_id, int num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
void fill(int thread_id, int num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
|
||||||
|
void fill(int thread_id, int num_threads, FUNC_DO func);
|
||||||
|
void fill(int thread_id, int num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
void fill(int thread_id, int num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
|
||||||
|
void fill(int thread_id, int num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
|
||||||
|
|
||||||
|
bool available();
|
||||||
|
void markAvailable();
|
||||||
|
void markUnavailable();
|
||||||
|
|
||||||
|
void finish();
|
||||||
|
|
||||||
|
void execute();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_CALLABLEINTERFACE_H
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef DEV_TESTS_CALLABLEWITHARGUMENTS_H
|
||||||
|
#define DEV_TESTS_CALLABLEWITHARGUMENTS_H
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
#include <atomic>
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
class CallableWithArguments {
|
||||||
|
FUNC_DO _function_do;
|
||||||
|
FUNC_1D _function_1d;
|
||||||
|
FUNC_2D _function_2d;
|
||||||
|
FUNC_3D _function_3d;
|
||||||
|
|
||||||
|
std::vector<int64_t> _arguments;
|
||||||
|
|
||||||
|
std::atomic<bool> _finished;
|
||||||
|
|
||||||
|
std::condition_variable _condition;
|
||||||
|
|
||||||
|
std::mutex _lock;
|
||||||
|
|
||||||
|
int _dimensions = 0;
|
||||||
|
|
||||||
|
uint64_t _threadId;
|
||||||
|
uint64_t _numThreads;
|
||||||
|
public:
|
||||||
|
CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads);
|
||||||
|
CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x);
|
||||||
|
CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y);
|
||||||
|
CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of dimensions
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
int dimensions();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method checks if this callable is finished
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
bool finished();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* this method marks this Callable as finished
|
||||||
|
*/
|
||||||
|
void finish();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method blocks until callable is finished
|
||||||
|
*/
|
||||||
|
void waitUntilFinished();
|
||||||
|
|
||||||
|
std::vector<int64_t>& arguments();
|
||||||
|
FUNC_DO function_do();
|
||||||
|
FUNC_1D function_1d();
|
||||||
|
FUNC_2D function_2d();
|
||||||
|
FUNC_3D function_3d();
|
||||||
|
|
||||||
|
|
||||||
|
uint64_t threadId();
|
||||||
|
|
||||||
|
uint64_t numThreads();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_CALLABLEWITHARGUMENTS_H
|
|
@ -0,0 +1,71 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_THREADPOOL_H
|
||||||
|
#define SAMEDIFF_THREADPOOL_H
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <vector>
|
||||||
|
#include <thread>
|
||||||
|
#include <atomic>
|
||||||
|
#include <mutex>
|
||||||
|
#include <execution/BlockingQueue.h>
|
||||||
|
#include <execution/CallableWithArguments.h>
|
||||||
|
#include <execution/CallableInterface.h>
|
||||||
|
#include <execution/Ticket.h>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
class ThreadPool {
|
||||||
|
private:
|
||||||
|
static ThreadPool* _INSTANCE;
|
||||||
|
|
||||||
|
std::vector<std::thread*> _threads;
|
||||||
|
std::vector<BlockingQueue<CallableWithArguments*>*> _queues;
|
||||||
|
std::vector<CallableInterface*> _interfaces;
|
||||||
|
|
||||||
|
std::mutex _lock;
|
||||||
|
std::atomic<int> _available;
|
||||||
|
std::queue<Ticket*> _tickets;
|
||||||
|
protected:
|
||||||
|
ThreadPool();
|
||||||
|
~ThreadPool();
|
||||||
|
public:
|
||||||
|
static ThreadPool* getInstance();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns list of pointers to threads ONLY if num_threads of threads were available upon request, returning empty list otherwise
|
||||||
|
* @param num_threads
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
Ticket* tryAcquire(int num_threads);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method marks specified number of threads as released, and available for use
|
||||||
|
* @param num_threads
|
||||||
|
*/
|
||||||
|
void release(int num_threads = 1);
|
||||||
|
|
||||||
|
void release(Ticket *ticket);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_THREADPOOL_H
|
|
@ -0,0 +1,160 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
#ifndef SAMEDIFF_THREADS_H
|
||||||
|
#define SAMEDIFF_THREADS_H
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <openmp_pragmas.h>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <Environment.h>
|
||||||
|
#include <op_enums.h>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
class ThreadsHelper {
|
||||||
|
public:
|
||||||
|
static int numberOfThreads(int maxThreads, uint64_t numberOfElements);
|
||||||
|
static int numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y);
|
||||||
|
static int numberOfThreads3d(int maxThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z);
|
||||||
|
static int pickLoop2d(int numThreads, uint64_t iters_x, uint64_t iters_y);
|
||||||
|
static int pickLoop3d(int numThreads, uint64_t iters_x, uint64_t iters_y, uint64_t iters_z);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Span {
|
||||||
|
private:
|
||||||
|
int64_t _startX, _stopX, _incX;
|
||||||
|
public:
|
||||||
|
Span(int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
~Span() = default;
|
||||||
|
|
||||||
|
int64_t startX() const;
|
||||||
|
int64_t stopX() const;
|
||||||
|
int64_t incX() const;
|
||||||
|
|
||||||
|
static Span build(uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Span2 {
|
||||||
|
private:
|
||||||
|
int64_t _startX, _stopX, _incX;
|
||||||
|
int64_t _startY, _stopY, _incY;
|
||||||
|
public:
|
||||||
|
Span2(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
|
||||||
|
~Span2() = default;
|
||||||
|
|
||||||
|
int64_t startX() const;
|
||||||
|
int64_t startY() const;
|
||||||
|
|
||||||
|
int64_t stopX() const;
|
||||||
|
int64_t stopY() const;
|
||||||
|
|
||||||
|
int64_t incX() const;
|
||||||
|
int64_t incY() const;
|
||||||
|
|
||||||
|
static Span2 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Span3 {
|
||||||
|
private:
|
||||||
|
int64_t _startX, _stopX, _incX;
|
||||||
|
int64_t _startY, _stopY, _incY;
|
||||||
|
int64_t _startZ, _stopZ, _incZ;
|
||||||
|
public:
|
||||||
|
Span3(int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
|
||||||
|
~Span3() = default;
|
||||||
|
|
||||||
|
int64_t startX() const;
|
||||||
|
int64_t startY() const;
|
||||||
|
int64_t startZ() const;
|
||||||
|
|
||||||
|
int64_t stopX() const;
|
||||||
|
int64_t stopY() const;
|
||||||
|
int64_t stopZ() const;
|
||||||
|
|
||||||
|
int64_t incX() const;
|
||||||
|
int64_t incY() const;
|
||||||
|
int64_t incZ() const;
|
||||||
|
|
||||||
|
static Span3 build(int loop, uint64_t thread_id, uint64_t num_threads, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z);
|
||||||
|
};
|
||||||
|
|
||||||
|
class Threads {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* This function executes 1 dimensional loop for a given number of threads
|
||||||
|
* PLEASE NOTE: this function can use smaller number of threads than requested.
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
* @param numThreads
|
||||||
|
* @param start
|
||||||
|
* @param stop
|
||||||
|
* @param increment
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
|
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
* @param numThreads
|
||||||
|
* @param start_x
|
||||||
|
* @param stop_x
|
||||||
|
* @param inc_x
|
||||||
|
* @param start_y
|
||||||
|
* @param stop_y
|
||||||
|
* @param inc_y
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads(), bool debug = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
* @param numThreads
|
||||||
|
* @param start_x
|
||||||
|
* @param stop_x
|
||||||
|
* @param inc_x
|
||||||
|
* @param start_y
|
||||||
|
* @param stop_y
|
||||||
|
* @param inc_y
|
||||||
|
* @param start_z
|
||||||
|
* @param stop_z
|
||||||
|
* @param inc_z
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param function
|
||||||
|
* @param numThreads
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
|
static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
|
static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //SAMEDIFF_THREADS_H
|
|
@ -0,0 +1,67 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SAMEDIFF_TICKET_H
|
||||||
|
#define SAMEDIFF_TICKET_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <execution/BlockingQueue.h>
|
||||||
|
#include <execution/CallableWithArguments.h>
|
||||||
|
#include <execution/CallableInterface.h>
|
||||||
|
#include <atomic>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
class Ticket {
|
||||||
|
private:
|
||||||
|
bool _acquired = false;
|
||||||
|
std::vector<BlockingQueue<CallableWithArguments*>*> _queues;
|
||||||
|
std::vector<CallableWithArguments*> _callables;
|
||||||
|
std::vector<CallableInterface*> _interfaces;
|
||||||
|
|
||||||
|
uint32_t _acquiredThreads = 0;
|
||||||
|
public:
|
||||||
|
explicit Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues);
|
||||||
|
Ticket();
|
||||||
|
~Ticket() = default;
|
||||||
|
|
||||||
|
bool acquired();
|
||||||
|
|
||||||
|
void acquiredThreads(uint32_t threads);
|
||||||
|
|
||||||
|
void attach(uint32_t thread_id, CallableInterface *interface);
|
||||||
|
|
||||||
|
// deprecated one
|
||||||
|
void enqueue(int thread_id, CallableWithArguments* callable);
|
||||||
|
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, double *lpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func);
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x);
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y);
|
||||||
|
void enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_, int64_t stop_z, int64_t inc_z);
|
||||||
|
|
||||||
|
void waitAndRelease();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_TICKET_H
|
|
@ -0,0 +1,73 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <execution/BlockingQueue.h>
|
||||||
|
#include <CallableWithArguments.h>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
template <typename T>
|
||||||
|
BlockingQueue<T>::BlockingQueue(int queueSize) {
|
||||||
|
_size = 0;
|
||||||
|
_available = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T BlockingQueue<T>::poll() {
|
||||||
|
// locking untill there's something within queue
|
||||||
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
|
_condition.wait(lock, [&]{ return this->_size.load() != 0; });
|
||||||
|
|
||||||
|
T t(std::move(_queue.front()));
|
||||||
|
_queue.pop();
|
||||||
|
_size--;
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void BlockingQueue<T>::put(const T &t) {
|
||||||
|
{
|
||||||
|
// locking before push, unlocking after
|
||||||
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
|
_queue.push(t);
|
||||||
|
_size++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// notifying condition
|
||||||
|
_condition.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool BlockingQueue<T>::available() {
|
||||||
|
return _available.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void BlockingQueue<T>::markAvailable() {
|
||||||
|
_available = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void BlockingQueue<T>::markUnavailable() {
|
||||||
|
_available = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template class BlockingQueue<CallableWithArguments*>;
|
||||||
|
}
|
|
@ -0,0 +1,213 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <execution/CallableInterface.h>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
CallableInterface::CallableInterface() {
|
||||||
|
// initial state is available
|
||||||
|
_available = true;
|
||||||
|
_filled = false;
|
||||||
|
_finished = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CallableInterface::available() {
|
||||||
|
return _available.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::markUnavailable() {
|
||||||
|
_available = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::markAvailable() {
|
||||||
|
_available = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, FUNC_DO func) {
|
||||||
|
_function_do = std::move(func);
|
||||||
|
|
||||||
|
_branch = 0;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, FUNC_1D func, int64_t startX, int64_t stopX, int64_t incX) {
|
||||||
|
_function_1d = std::move(func);
|
||||||
|
_arguments[0] = startX;
|
||||||
|
_arguments[1] = stopX;
|
||||||
|
_arguments[2] = incX;
|
||||||
|
|
||||||
|
_branch = 1;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, FUNC_2D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y) {
|
||||||
|
_function_2d = std::move(func);
|
||||||
|
_arguments[0] = startX;
|
||||||
|
_arguments[1] = stopX;
|
||||||
|
_arguments[2] = incX;
|
||||||
|
_arguments[3] = start_y;
|
||||||
|
_arguments[4] = stop_y;
|
||||||
|
_arguments[5] = inc_y;
|
||||||
|
|
||||||
|
_branch = 2;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, FUNC_3D func, int64_t startX, int64_t stopX, int64_t incX, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) {
|
||||||
|
_function_3d = std::move(func);
|
||||||
|
_arguments[0] = startX;
|
||||||
|
_arguments[1] = stopX;
|
||||||
|
_arguments[2] = incX;
|
||||||
|
_arguments[3] = start_y;
|
||||||
|
_arguments[4] = stop_y;
|
||||||
|
_arguments[5] = inc_y;
|
||||||
|
_arguments[6] = start_z;
|
||||||
|
_arguments[7] = stop_z;
|
||||||
|
_arguments[8] = inc_z;
|
||||||
|
|
||||||
|
_branch = 3;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, int64_t *lptr, FUNC_RL func, int64_t startX, int64_t stopX, int64_t incX) {
|
||||||
|
_function_rl = std::move(func);
|
||||||
|
_arguments[0] = startX;
|
||||||
|
_arguments[1] = stopX;
|
||||||
|
_arguments[2] = incX;
|
||||||
|
|
||||||
|
_lptr = lptr;
|
||||||
|
|
||||||
|
_branch = 4;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::fill(int threadID, int numThreads, double *dptr, FUNC_RD func, int64_t startX, int64_t stopX, int64_t incX) {
|
||||||
|
_function_rd = std::move(func);
|
||||||
|
_arguments[0] = startX;
|
||||||
|
_arguments[1] = stopX;
|
||||||
|
_arguments[2] = incX;
|
||||||
|
|
||||||
|
_dptr = dptr;
|
||||||
|
|
||||||
|
_branch = 5;
|
||||||
|
_num_threads = numThreads;
|
||||||
|
_thread_id = threadID;
|
||||||
|
_finished = false;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_ms);
|
||||||
|
_filled = true;
|
||||||
|
}
|
||||||
|
_starter.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::waitForTask() {
|
||||||
|
// block until task is available
|
||||||
|
std::unique_lock<std::mutex> lock(_ms);
|
||||||
|
_starter.wait(lock, [&]{ return _filled.load(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::waitForCompletion() {
|
||||||
|
//while (!_finished.load());
|
||||||
|
|
||||||
|
// block until finished
|
||||||
|
std::unique_lock<std::mutex> lock(_mf);
|
||||||
|
_finisher.wait(lock, [&] { return _finished.load(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::finish() {
|
||||||
|
// mark as finished
|
||||||
|
{
|
||||||
|
std::unique_lock<std::mutex> l(_mf);
|
||||||
|
_finished.store(true);
|
||||||
|
}
|
||||||
|
_finisher.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableInterface::execute() {
|
||||||
|
// mark it as consumed
|
||||||
|
_filled = false;
|
||||||
|
|
||||||
|
// actually executing op
|
||||||
|
switch (_branch) {
|
||||||
|
case 0:
|
||||||
|
_function_do(_thread_id, _num_threads);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
_function_1d(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
_function_2d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5]);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
_function_3d(_thread_id, _arguments[0], _arguments[1], _arguments[2], _arguments[3], _arguments[4], _arguments[5], _arguments[6], _arguments[7], _arguments[8]);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
_lptr[0] = _function_rl(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
_dptr[0] = _function_rd(_thread_id, _arguments[0], _arguments[1], _arguments[2]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// notify that thread finished the job
|
||||||
|
this->finish();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,103 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <execution/CallableWithArguments.h>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
CallableWithArguments::CallableWithArguments(FUNC_DO func, uint64_t thread_id, uint64_t numThreads) {
|
||||||
|
_function_do = func;
|
||||||
|
_finished = false;
|
||||||
|
_threadId = thread_id;
|
||||||
|
_numThreads = numThreads;
|
||||||
|
_dimensions = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
CallableWithArguments::CallableWithArguments(FUNC_3D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y, int64_t start_z, int64_t stop_z, int64_t increment_z) {
|
||||||
|
_function_3d = func;
|
||||||
|
_arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y, start_z, stop_z, increment_z};
|
||||||
|
_finished = false;
|
||||||
|
_threadId = thread_id;
|
||||||
|
_dimensions = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
CallableWithArguments::CallableWithArguments(FUNC_1D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x) {
|
||||||
|
_function_1d = func;
|
||||||
|
_arguments = {start_x, stop_x, increment_x};
|
||||||
|
_finished = false;
|
||||||
|
_threadId = thread_id;
|
||||||
|
_dimensions = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
CallableWithArguments::CallableWithArguments(FUNC_2D func, uint64_t thread_id, int64_t start_x, int64_t stop_x, int64_t increment_x, int64_t start_y, int64_t stop_y, int64_t increment_y) {
|
||||||
|
_function_2d = func;
|
||||||
|
_arguments = {start_x, stop_x, increment_x, start_y, stop_y, increment_y};
|
||||||
|
_finished = false;
|
||||||
|
_threadId = thread_id;
|
||||||
|
_dimensions = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
int CallableWithArguments::dimensions() {
|
||||||
|
return _dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t>& CallableWithArguments::arguments() {
|
||||||
|
return _arguments;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CallableWithArguments::finished() {
|
||||||
|
return _finished.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableWithArguments::finish() {
|
||||||
|
std::lock_guard<std::mutex> lock(_lock);
|
||||||
|
_finished = true;
|
||||||
|
_condition.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CallableWithArguments::waitUntilFinished() {
|
||||||
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
|
_condition.wait(lock, [&]{ return _finished.load(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
FUNC_1D CallableWithArguments::function_1d() {
|
||||||
|
return _function_1d;
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNC_2D CallableWithArguments::function_2d() {
|
||||||
|
return _function_2d;
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNC_DO CallableWithArguments::function_do() {
|
||||||
|
return _function_do;
|
||||||
|
}
|
||||||
|
|
||||||
|
FUNC_3D CallableWithArguments::function_3d() {
|
||||||
|
return _function_3d;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t CallableWithArguments::threadId() {
|
||||||
|
return _threadId;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t CallableWithArguments::numThreads() {
|
||||||
|
return _numThreads;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,194 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
//#include <windows.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
|
||||||
|
// this function executed once per thread, it polls functions from queue, and executes them via wrapper
|
||||||
|
static void executionLoop_(int thread_id, BlockingQueue<CallableWithArguments*> *queue) {
|
||||||
|
while (true) {
|
||||||
|
// this method blocks until there's something within queue
|
||||||
|
auto c = queue->poll();
|
||||||
|
//nd4j_printf("ThreadPool: starting thread %i\n", c->threadId());
|
||||||
|
switch (c->dimensions()) {
|
||||||
|
case 0: {
|
||||||
|
c->function_do()(c->threadId(), c->numThreads());
|
||||||
|
c->finish();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 1: {
|
||||||
|
auto args = c->arguments();
|
||||||
|
c->function_1d()(c->threadId(), args[0], args[1], args[2]);
|
||||||
|
c->finish();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2: {
|
||||||
|
auto args = c->arguments();
|
||||||
|
c->function_2d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5]);
|
||||||
|
c->finish();
|
||||||
|
//nd4j_printf("ThreadPool: finished thread %i\n", c->threadId());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3: {
|
||||||
|
auto args = c->arguments();
|
||||||
|
c->function_3d()(c->threadId(), args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]);
|
||||||
|
c->finish();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Don't know what to do with provided Callable");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void executionLoopWithInterface_(int thread_id, CallableInterface *c) {
|
||||||
|
while (true) {
|
||||||
|
// blocking here until there's something to do
|
||||||
|
c->waitForTask();
|
||||||
|
|
||||||
|
// execute whatever we have
|
||||||
|
c->execute();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool::ThreadPool() {
|
||||||
|
// TODO: number of threads must reflect number of cores for UMA system. In case of NUMA it should be per-device pool
|
||||||
|
// FIXME: on mobile phones this feature must NOT be used
|
||||||
|
_available = nd4j::Environment::getInstance()->maxThreads();
|
||||||
|
|
||||||
|
_queues.resize(_available.load());
|
||||||
|
_threads.resize(_available.load());
|
||||||
|
_interfaces.resize(_available.load());
|
||||||
|
|
||||||
|
// creating threads here
|
||||||
|
for (int e = 0; e < _available.load(); e++) {
|
||||||
|
_queues[e] = new BlockingQueue<CallableWithArguments*>(2);
|
||||||
|
_interfaces[e] = new CallableInterface();
|
||||||
|
_threads[e] = new std::thread(executionLoopWithInterface_, e, _interfaces[e]);
|
||||||
|
_tickets.push(new Ticket());
|
||||||
|
// _threads[e] = new std::thread(executionLoop_, e, _queues[e]);
|
||||||
|
|
||||||
|
// TODO: add other platforms here as well
|
||||||
|
// now we must set affinity, and it's going to be platform-specific thing
|
||||||
|
#ifdef LINUX_BUILD
|
||||||
|
cpu_set_t cpuset;
|
||||||
|
CPU_ZERO(&cpuset);
|
||||||
|
CPU_SET(e, &cpuset);
|
||||||
|
int rc = pthread_setaffinity_np(_threads[e]->native_handle(), sizeof(cpu_set_t), &cpuset);
|
||||||
|
if (rc != 0)
|
||||||
|
throw std::runtime_error("Failed to set pthread affinity");
|
||||||
|
#endif
|
||||||
|
/*
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
// we can't set affinity to more than 64 cores
|
||||||
|
if (e <= 64) {
|
||||||
|
auto mask = (static_cast<DWORD_PTR>(1) << e);
|
||||||
|
auto result = SetThreadAffinityMask(_threads[e]->native_handle(), mask);
|
||||||
|
if (!result)
|
||||||
|
throw std::runtime_error("Failed to set pthread affinity");
|
||||||
|
}
|
||||||
|
|
||||||
|
// that's fine. no need for time_critical here
|
||||||
|
SetThreadPriority(_threads[e]->native_handle(), THREAD_PRIORITY_HIGHEST);
|
||||||
|
#endif
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ThreadPool::~ThreadPool() {
|
||||||
|
// TODO: implement this one properly
|
||||||
|
for (int e = 0; e < _queues.size(); e++) {
|
||||||
|
// stop each and every thread
|
||||||
|
|
||||||
|
// release queue and thread
|
||||||
|
//delete _queues[e];
|
||||||
|
//delete _threads[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::mutex _lmutex;
|
||||||
|
|
||||||
|
ThreadPool* ThreadPool::getInstance() {
|
||||||
|
std::unique_lock<std::mutex> lock(_lmutex);
|
||||||
|
if (!_INSTANCE)
|
||||||
|
_INSTANCE = new ThreadPool();
|
||||||
|
|
||||||
|
return _INSTANCE;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadPool::release(int numThreads) {
|
||||||
|
_available += numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ticket* ThreadPool::tryAcquire(int numThreads) {
|
||||||
|
//std::vector<BlockingQueue<CallableWithArguments*>*> queues;
|
||||||
|
|
||||||
|
Ticket *t = nullptr;
|
||||||
|
// we check for threads availability first
|
||||||
|
bool threaded = false;
|
||||||
|
{
|
||||||
|
// we lock before checking availability
|
||||||
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
|
if (_available >= numThreads) {
|
||||||
|
threaded = true;
|
||||||
|
_available -= numThreads;
|
||||||
|
|
||||||
|
// getting a ticket from the queue
|
||||||
|
t = _tickets.front();
|
||||||
|
_tickets.pop();
|
||||||
|
|
||||||
|
// ticket must contain information about number of threads for the current session
|
||||||
|
t->acquiredThreads(numThreads);
|
||||||
|
|
||||||
|
// filling ticket with executable interfaces
|
||||||
|
for (int e = 0, i = 0; e < _queues.size() && i < numThreads; e++) {
|
||||||
|
if (_interfaces[e]->available()) {
|
||||||
|
t->attach(i++, _interfaces[e]);
|
||||||
|
_interfaces[e]->markUnavailable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we either dispatch tasks to threads, or run single-threaded
|
||||||
|
if (threaded) {
|
||||||
|
return t;
|
||||||
|
} else {
|
||||||
|
// if there's no threads available - return nullptr
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ThreadPool::release(samediff::Ticket *ticket) {
|
||||||
|
// returning ticket back to the queue
|
||||||
|
std::unique_lock<std::mutex> lock(_lock);
|
||||||
|
_tickets.push(ticket);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ThreadPool* ThreadPool::_INSTANCE = 0;
|
||||||
|
}
|
|
@ -0,0 +1,641 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <thread>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
#include <templatemath.h>
|
||||||
|
#include <shape.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
|
||||||
|
int ThreadsHelper::numberOfThreads(int maxThreads, uint64_t numberOfElements) {
|
||||||
|
// let's see how many threads we actually need first
|
||||||
|
auto optimalThreads = nd4j::math::nd4j_max<uint64_t>(1, numberOfElements / 1024);
|
||||||
|
|
||||||
|
// now return the smallest value
|
||||||
|
return nd4j::math::nd4j_min<int>(optimalThreads, maxThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
Span3::Span3(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
|
||||||
|
_startX = startX;
|
||||||
|
_startY = startY;
|
||||||
|
_startZ = startZ;
|
||||||
|
_stopX = stopX;
|
||||||
|
_stopY = stopY;
|
||||||
|
_stopZ = stopZ;
|
||||||
|
_incX = incX;
|
||||||
|
_incY = incY;
|
||||||
|
_incZ = incZ;
|
||||||
|
}
|
||||||
|
|
||||||
|
Span3 Span3::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ) {
|
||||||
|
switch (loop) {
|
||||||
|
case 1: {
|
||||||
|
auto span = (stopX - startX) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopX;
|
||||||
|
|
||||||
|
return Span3(s, e, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2: {
|
||||||
|
auto span = (stopY - startY) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopY;
|
||||||
|
|
||||||
|
return Span3(startX, stopX, incX, s, e, incY, startZ, stopZ, incZ);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3: {
|
||||||
|
auto span = (stopZ - startZ) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopZ;
|
||||||
|
|
||||||
|
return Span3(startX, stopX, incX, startY, stopY, incY, s, e, incZ);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("");
|
||||||
|
}
|
||||||
|
return Span3(startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
Span::Span(int64_t startX, int64_t stopX, int64_t incX) {
|
||||||
|
_startX = startX;
|
||||||
|
_stopX = stopX;
|
||||||
|
_incX = incX;
|
||||||
|
}
|
||||||
|
|
||||||
|
Span Span::build(uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX) {
|
||||||
|
auto span = (stopX - startX) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopX;
|
||||||
|
|
||||||
|
return Span(s, e, incX);
|
||||||
|
}
|
||||||
|
|
||||||
|
Span2::Span2(int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) {
|
||||||
|
_startX = startX;
|
||||||
|
_startY = startY;
|
||||||
|
_stopX = stopX;
|
||||||
|
_stopY = stopY;
|
||||||
|
_incX = incX;
|
||||||
|
_incY = incY;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Span2 Span2::build(int loop, uint64_t threadID, uint64_t numThreads, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY) {
|
||||||
|
|
||||||
|
switch (loop) {
|
||||||
|
case 1: {
|
||||||
|
auto span = (stopX - startX) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopX;
|
||||||
|
|
||||||
|
return Span2(s, e, incX, startY, stopY, incY);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2: {
|
||||||
|
auto span = (stopY - startY) / numThreads;
|
||||||
|
auto s = span * threadID;
|
||||||
|
auto e = s + span;
|
||||||
|
if (threadID == numThreads - 1)
|
||||||
|
e = stopY;
|
||||||
|
|
||||||
|
return Span2(startX, stopX, incX, s, e, incY);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span::startX() const {
|
||||||
|
return _startX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span::stopX() const {
|
||||||
|
return _stopX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span::incX() const {
|
||||||
|
return _incX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::startX() const {
|
||||||
|
return _startX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::startY() const {
|
||||||
|
return _startY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::stopX() const {
|
||||||
|
return _stopX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::stopY() const {
|
||||||
|
return _stopY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::incX() const {
|
||||||
|
return _incX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span2::incY() const {
|
||||||
|
return _incY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::startX() const {
|
||||||
|
return _startX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::startY() const {
|
||||||
|
return _startY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::startZ() const {
|
||||||
|
return _startZ;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::stopX() const {
|
||||||
|
return _stopX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::stopY() const {
|
||||||
|
return _stopY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::stopZ() const {
|
||||||
|
return _stopZ;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::incX() const {
|
||||||
|
return _incX;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::incY() const {
|
||||||
|
return _incY;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Span3::incZ() const {
|
||||||
|
return _incZ;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ThreadsHelper::pickLoop2d(int numThreads, uint64_t itersX, uint64_t itersY) {
|
||||||
|
// if one of dimensions is definitely too small - we just pick the other one
|
||||||
|
if (itersX < numThreads && itersY >= numThreads)
|
||||||
|
return 2;
|
||||||
|
if (itersY < numThreads && itersX >= numThreads)
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
// next step - we pick the most balanced dimension
|
||||||
|
auto remX = itersX % numThreads;
|
||||||
|
auto remY = itersY % numThreads;
|
||||||
|
auto splitY = itersY / numThreads;
|
||||||
|
|
||||||
|
// if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
|
||||||
|
if (remX == 0)
|
||||||
|
return 1;
|
||||||
|
if (remY == 0)
|
||||||
|
return 2;
|
||||||
|
|
||||||
|
// if there's no loop without a remainder - we're picking one with smaller remainder
|
||||||
|
if (remX < remY)
|
||||||
|
return 1;
|
||||||
|
if (remY < remX && splitY >= 64) // we don't want too small splits over last dimension, or vectorization will fail
|
||||||
|
return 2;
|
||||||
|
// if loops are equally sized - give the preference to the first thread
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static int threads_(int maxThreads, uint64_t elements) {
|
||||||
|
|
||||||
|
if (elements == maxThreads) {
|
||||||
|
return maxThreads;
|
||||||
|
}
|
||||||
|
else if (elements > maxThreads) {
|
||||||
|
// if we have full load across thread, or at least half of threads can be utilized
|
||||||
|
auto rem = elements % maxThreads;
|
||||||
|
if (rem == 0 || rem >= maxThreads / 3)
|
||||||
|
return maxThreads;
|
||||||
|
else
|
||||||
|
return threads_(maxThreads - 1, elements);
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (elements < maxThreads) {
|
||||||
|
return elements;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ThreadsHelper::numberOfThreads2d(int maxThreads, uint64_t iters_x, uint64_t iters_y) {
|
||||||
|
// in some cases there's nothing to think about, part 1
|
||||||
|
if (iters_x < maxThreads && iters_y < maxThreads)
|
||||||
|
return nd4j::math::nd4j_max<int>(iters_x, iters_y);
|
||||||
|
|
||||||
|
auto remX = iters_x % maxThreads;
|
||||||
|
auto remY = iters_y % maxThreads;
|
||||||
|
|
||||||
|
// in some cases there's nothing to think about, part 2
|
||||||
|
if ((iters_x >= maxThreads && remX == 0 )|| (iters_y >= maxThreads && remY == 0))
|
||||||
|
return maxThreads;
|
||||||
|
|
||||||
|
// at this point we suppose that there's no loop perfectly matches number of our threads
|
||||||
|
// so let's pick something as equal as possible
|
||||||
|
if (iters_x > maxThreads || iters_y > maxThreads)
|
||||||
|
return maxThreads;
|
||||||
|
else
|
||||||
|
return numberOfThreads2d(maxThreads - 1, iters_x, iters_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
int ThreadsHelper::numberOfThreads3d(int maxThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) {
|
||||||
|
// we don't want to run underloaded threads
|
||||||
|
if (itersX * itersY * itersZ <= 32)
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
auto remX = itersX % maxThreads;
|
||||||
|
auto remY = itersY % maxThreads;
|
||||||
|
auto remZ = itersZ % maxThreads;
|
||||||
|
|
||||||
|
// if we have perfect balance across one of dimensions - just go for it
|
||||||
|
if ((itersX >= maxThreads && remX == 0) || (itersY >= maxThreads && remY == 0) || (itersZ >= maxThreads && remZ == 0))
|
||||||
|
return maxThreads;
|
||||||
|
|
||||||
|
int threadsX = 0, threadsY = 0, threadsZ = 0;
|
||||||
|
|
||||||
|
// now we look into possible number of
|
||||||
|
threadsX = threads_(maxThreads, itersX);
|
||||||
|
threadsY = threads_(maxThreads, itersY);
|
||||||
|
threadsZ = threads_(maxThreads, itersZ);
|
||||||
|
|
||||||
|
// we want to split as close to outer loop as possible, so checking it out first
|
||||||
|
if (threadsX >= threadsY && threadsX >= threadsZ)
|
||||||
|
return threadsX;
|
||||||
|
else if (threadsY >= threadsX && threadsY >= threadsZ)
|
||||||
|
return threadsY;
|
||||||
|
else if (threadsZ >= threadsX && threadsZ >= threadsY)
|
||||||
|
return threadsZ;
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ThreadsHelper::pickLoop3d(int numThreads, uint64_t itersX, uint64_t itersY, uint64_t itersZ) {
|
||||||
|
auto remX = itersX % numThreads;
|
||||||
|
auto remY = itersY % numThreads;
|
||||||
|
auto remZ = itersZ % numThreads;
|
||||||
|
|
||||||
|
auto splitX = itersX / numThreads;
|
||||||
|
auto splitY = itersY / numThreads;
|
||||||
|
auto splitZ = itersZ / numThreads;
|
||||||
|
|
||||||
|
// if there's no remainder left in some dimension - we're picking that dimension, because it'll be the most balanced work distribution
|
||||||
|
if (remX == 0)
|
||||||
|
return 1;
|
||||||
|
else if (remY == 0)
|
||||||
|
return 2;
|
||||||
|
else if (remZ == 0) // TODO: we don't want too smal splits over last dimension? or we do?
|
||||||
|
return 3;
|
||||||
|
|
||||||
|
if (itersX > numThreads)
|
||||||
|
return 1;
|
||||||
|
else if (itersY > numThreads)
|
||||||
|
return 2;
|
||||||
|
else if (itersZ > numThreads)
|
||||||
|
return 3;
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Threads::parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
||||||
|
if (start > stop)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got start > stop");
|
||||||
|
|
||||||
|
auto delta = (stop - start);
|
||||||
|
|
||||||
|
if (numThreads > delta)
|
||||||
|
numThreads = delta;
|
||||||
|
|
||||||
|
if (numThreads == 0)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
// shortcut
|
||||||
|
if (numThreads == 1) {
|
||||||
|
function(0, start, stop, increment);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
// if we got our threads - we'll run our jobs here
|
||||||
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = start_ + span;
|
||||||
|
|
||||||
|
// last thread will process tail
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
stop_ = stop;
|
||||||
|
|
||||||
|
// putting the task into the queue for a given thread
|
||||||
|
ticket->enqueue(e, numThreads, function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
// block and wait till all threads finished the job
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
// we tell that parallelism request succeeded
|
||||||
|
return numThreads;
|
||||||
|
} else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, start, stop, increment);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int Threads::parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment, uint32_t numThreads) {
|
||||||
|
if (start > stop)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got start > stop");
|
||||||
|
|
||||||
|
auto delta = (stop - start);
|
||||||
|
|
||||||
|
// in some cases we just fire func as is
|
||||||
|
if (delta == 0 || numThreads == 1) {
|
||||||
|
function(0, start, stop, increment);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto numElements = delta / increment;
|
||||||
|
|
||||||
|
// we decide what's optimal number of threads we need here, and execute it in parallel_tad.
|
||||||
|
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
|
||||||
|
return parallel_tad(function, start, stop, increment, numThreads);
|
||||||
|
}
|
||||||
|
|
||||||
|
int Threads::parallel_for(FUNC_2D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, uint64_t numThreads, bool debug) {
|
||||||
|
if (startX > stopX)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got startX > stopX");
|
||||||
|
|
||||||
|
if (startY > stopY)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got startY > stopY");
|
||||||
|
|
||||||
|
// number of elements per loop
|
||||||
|
auto delta_x = (stopX - startX);
|
||||||
|
auto delta_y = (stopY - startY);
|
||||||
|
|
||||||
|
// number of iterations per loop
|
||||||
|
auto itersX = delta_x / incX;
|
||||||
|
auto itersY = delta_y / incY;
|
||||||
|
|
||||||
|
// total number of iterations
|
||||||
|
auto iters_t = itersX * itersY;
|
||||||
|
|
||||||
|
// we are checking the case of number of requested threads was smaller
|
||||||
|
numThreads = ThreadsHelper::numberOfThreads2d(numThreads, itersX, itersY);
|
||||||
|
|
||||||
|
// basic shortcut for no-threading cases
|
||||||
|
if (numThreads == 1) {
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have couple of scenarios:
|
||||||
|
// either we split workload along 1st loop, or 2nd
|
||||||
|
auto splitLoop = ThreadsHelper::pickLoop2d(numThreads, itersX, itersY);
|
||||||
|
|
||||||
|
// for debug mode we execute things inplace, without any threads
|
||||||
|
if (debug) {
|
||||||
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
auto span = Span2::build(splitLoop, e, numThreads, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
|
function(e, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
|
||||||
|
}
|
||||||
|
|
||||||
|
// but we still mimic multithreaded execution
|
||||||
|
return numThreads;
|
||||||
|
} else {
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
|
||||||
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
auto threadId = numThreads - e - 1;
|
||||||
|
auto span = Span2::build(splitLoop, threadId, numThreads, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
|
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY());
|
||||||
|
}
|
||||||
|
|
||||||
|
// block until all threads finish their job
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
} else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int Threads::parallel_for(FUNC_3D function, int64_t startX, int64_t stopX, int64_t incX, int64_t startY, int64_t stopY, int64_t incY, int64_t startZ, int64_t stopZ, int64_t incZ, uint64_t numThreads) {
|
||||||
|
if (startX > stopX)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got startX > stopX");
|
||||||
|
|
||||||
|
if (startY > stopY)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got startY > stopY");
|
||||||
|
|
||||||
|
if (startZ > stopZ)
|
||||||
|
throw std::runtime_error("Threads::parallel_for got startZ > stopZ");
|
||||||
|
|
||||||
|
auto delta_x = stopX - startX;
|
||||||
|
auto delta_y = stopY - startY;
|
||||||
|
auto delta_z = stopZ - startZ;
|
||||||
|
|
||||||
|
auto itersX = delta_x / incX;
|
||||||
|
auto itersY = delta_y / incY;
|
||||||
|
auto itersZ = delta_z / incZ;
|
||||||
|
|
||||||
|
numThreads = 1; //ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ);
|
||||||
|
if (numThreads == 1) {
|
||||||
|
// loop is too small - executing function as is
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
auto splitLoop = ThreadsHelper::pickLoop3d(numThreads, itersX, itersY, itersZ);
|
||||||
|
|
||||||
|
for (int e = 0; e < numThreads; e++) {
|
||||||
|
auto thread_id = numThreads - e - 1;
|
||||||
|
auto span = Span3::build(splitLoop, thread_id, numThreads, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
|
||||||
|
ticket->enqueue(e, numThreads, function, span.startX(), span.stopX(), span.incX(), span.startY(), span.stopY(), span.incY(), span.startZ(), span.stopZ(), span.incZ());
|
||||||
|
}
|
||||||
|
|
||||||
|
// block until we're done
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
// we tell that parallelism request succeeded
|
||||||
|
return numThreads;
|
||||||
|
} else {
|
||||||
|
// if there were no threads available - we'll execute function right within current thread
|
||||||
|
function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ);
|
||||||
|
|
||||||
|
// we tell that parallelism request declined
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
int Threads::parallel_do(FUNC_DO function, uint64_t numThreads) {
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket != nullptr) {
|
||||||
|
|
||||||
|
// submit tasks one by one
|
||||||
|
for (uint64_t e = 0; e < numThreads - 1; e++)
|
||||||
|
ticket->enqueue(e, numThreads, function);
|
||||||
|
|
||||||
|
function(numThreads - 1, numThreads);
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
} else {
|
||||||
|
// if there's no threads available - we'll execute function sequentially one by one
|
||||||
|
for (uint64_t e = 0; e < numThreads; e++)
|
||||||
|
function(e, numThreads);
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return numThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t Threads::parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) {
|
||||||
|
if (start > stop)
|
||||||
|
throw std::runtime_error("Threads::parallel_long got start > stop");
|
||||||
|
|
||||||
|
auto delta = (stop - start);
|
||||||
|
if (delta == 0 || numThreads == 1)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
auto numElements = delta / increment;
|
||||||
|
|
||||||
|
// we decide what's optimal number of threads we need here, and execute it
|
||||||
|
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
|
||||||
|
if (numThreads == 1)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket == nullptr)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
// create temporary array
|
||||||
|
int64_t intermediatery[256];
|
||||||
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
|
// execute threads in parallel
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
intermediatery[e] = function(e, start_, stop, increment);
|
||||||
|
else
|
||||||
|
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
// aggregate results in single thread
|
||||||
|
for (uint64_t e = 1; e < numThreads; e++)
|
||||||
|
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
|
||||||
|
|
||||||
|
// return accumulated result
|
||||||
|
return intermediatery[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
double Threads::parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment, uint64_t numThreads) {
|
||||||
|
if (start > stop)
|
||||||
|
throw std::runtime_error("Threads::parallel_long got start > stop");
|
||||||
|
|
||||||
|
auto delta = (stop - start);
|
||||||
|
if (delta == 0 || numThreads == 1)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
auto numElements = delta / increment;
|
||||||
|
|
||||||
|
// we decide what's optimal number of threads we need here, and execute it
|
||||||
|
numThreads = ThreadsHelper::numberOfThreads(numThreads, numElements);
|
||||||
|
if (numThreads == 1)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
auto ticket = ThreadPool::getInstance()->tryAcquire(numThreads - 1);
|
||||||
|
if (ticket == nullptr)
|
||||||
|
return function(0, start, stop, increment);
|
||||||
|
|
||||||
|
// create temporary array
|
||||||
|
double intermediatery[256];
|
||||||
|
auto span = delta / numThreads;
|
||||||
|
|
||||||
|
// execute threads in parallel
|
||||||
|
for (uint32_t e = 0; e < numThreads; e++) {
|
||||||
|
auto start_ = span * e + start;
|
||||||
|
auto stop_ = span * (e + 1) + start;
|
||||||
|
|
||||||
|
if (e == numThreads - 1)
|
||||||
|
intermediatery[e] = function(e, start_, stop, increment);
|
||||||
|
else
|
||||||
|
ticket->enqueue(e, numThreads, &intermediatery[e], function, start_, stop_, increment);
|
||||||
|
}
|
||||||
|
|
||||||
|
ticket->waitAndRelease();
|
||||||
|
|
||||||
|
// aggregate results in single thread
|
||||||
|
for (uint64_t e = 1; e < numThreads; e++)
|
||||||
|
intermediatery[0] = aggregator(intermediatery[0], intermediatery[e]);
|
||||||
|
|
||||||
|
// return accumulated result
|
||||||
|
return intermediatery[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <execution/Ticket.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <helpers/logger.h>
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
namespace samediff {
|
||||||
|
Ticket::Ticket(const std::vector<BlockingQueue<CallableWithArguments*>*> &queues) {
|
||||||
|
_acquired = true;
|
||||||
|
_queues = queues;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ticket::Ticket() {
|
||||||
|
_acquired = true;
|
||||||
|
_interfaces.resize(nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Ticket::acquired() {
|
||||||
|
return _acquired;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(int thread_id, samediff::CallableWithArguments *callable) {
|
||||||
|
_queues[thread_id]->put(callable);
|
||||||
|
_callables.emplace_back(callable);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_DO func) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, func);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_1D func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, int64_t *lpt, FUNC_RL func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, lpt, func, start_x, stop_x, inc_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, double *dpt, FUNC_RD func, int64_t start_x, int64_t stop_x, int64_t inc_x) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, dpt, func, start_x, stop_x, inc_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_2D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, std::move(func), start_x, stop_x, inc_x, start_y, stop_y, inc_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::enqueue(uint32_t thread_id, uint32_t num_threads, FUNC_3D func, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z) {
|
||||||
|
_interfaces[thread_id]->fill(thread_id, num_threads, func, start_x, stop_x, inc_x, start_y, stop_y, inc_y, start_z, stop_z, inc_z);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::acquiredThreads(uint32_t threads) {
|
||||||
|
_acquiredThreads = threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Ticket::waitAndRelease() {
|
||||||
|
for (uint32_t e = 0; e < this->_acquiredThreads; e++) {
|
||||||
|
// block until finished
|
||||||
|
_interfaces[e]->waitForCompletion();
|
||||||
|
|
||||||
|
// mark available
|
||||||
|
_interfaces[e]->markAvailable();
|
||||||
|
|
||||||
|
// increment availability counter
|
||||||
|
ThreadPool::getInstance()->release();
|
||||||
|
}
|
||||||
|
|
||||||
|
// return this ticket back to the pool
|
||||||
|
ThreadPool::getInstance()->release(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void Ticket::attach(uint32_t thread_id, samediff::CallableInterface *interface) {
|
||||||
|
_interfaces[thread_id] = interface;
|
||||||
|
}
|
||||||
|
}
|
|
@ -232,6 +232,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
static nd4j::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar);
|
static nd4j::ops::DeclarableOp* buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar);
|
||||||
|
static void deleteOpByType(OpType opType, void *op);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <graph/Graph.h>
|
#include <graph/Graph.h>
|
||||||
|
#include <array/DataTypeUtils.h>
|
||||||
#include <helpers/EnumUtils.h>
|
#include <helpers/EnumUtils.h>
|
||||||
#include <graph/FlatUtils.h>
|
#include <graph/FlatUtils.h>
|
||||||
#include <NativeOps.h>
|
#include <NativeOps.h>
|
||||||
|
@ -154,7 +155,7 @@ namespace nd4j {
|
||||||
Nd4jLong *newShape = nullptr;
|
Nd4jLong *newShape = nullptr;
|
||||||
|
|
||||||
// if that's scalar output - we don't care about previous node
|
// if that's scalar output - we don't care about previous node
|
||||||
if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == MAX_INT)) {
|
if (node->getDimensions()->size() == 0 || (node->getDimensions()->size() == 1 && node->getDimensions()->at(0) == nd4j::DataTypeUtils::max<int>())) {
|
||||||
newShape = new Nd4jLong[8];
|
newShape = new Nd4jLong[8];
|
||||||
|
|
||||||
newShape[0] = 2;
|
newShape[0] = 2;
|
||||||
|
|
|
@ -682,8 +682,9 @@ namespace nd4j {
|
||||||
if (_protoContext != nullptr)
|
if (_protoContext != nullptr)
|
||||||
delete _protoContext;
|
delete _protoContext;
|
||||||
|
|
||||||
if (_isDeductable && _customOp != nullptr)
|
if (_isDeductable && _customOp != nullptr) {
|
||||||
delete _customOp;
|
Node::deleteOpByType(_opType, _customOp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int nd4j::graph::Node::getRewindNode() {
|
int nd4j::graph::Node::getRewindNode() {
|
||||||
|
@ -710,6 +711,70 @@ namespace nd4j {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void nd4j::graph::Node::deleteOpByType(OpType opType, void *op) {
|
||||||
|
switch (opType) {
|
||||||
|
case OpType_PAIRWISE:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyPairwiseTransformOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_PAIRWISE_BOOL:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyPairwiseTransformBoolOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_TRANSFORM_STRICT:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyTransformStrictOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_TRANSFORM_SAME:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyTransformSameOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_TRANSFORM_FLOAT:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyTransformFloatOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_TRANSFORM_BOOL:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyTransformBoolOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_SCALAR:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyScalarOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_SCALAR_BOOL:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyScalarBoolOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_REDUCE_3:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyReduce3Op*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_REDUCE_SAME:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyReduceSameOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_REDUCE_FLOAT:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyReduceFloatOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_REDUCE_LONG:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyReduceLongOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_REDUCE_BOOL:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyReduceBoolOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_INDEX_REDUCE:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyIndexReduceOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_SUMMARYSTATS:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyStatsOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_RANDOM:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyRandomOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_BROADCAST:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyBroadcastOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_BROADCAST_BOOL:
|
||||||
|
delete reinterpret_cast<nd4j::ops::LegacyBroadcastBoolOp*>(op);
|
||||||
|
break;
|
||||||
|
case OpType_CUSTOM:
|
||||||
|
delete reinterpret_cast<nd4j::ops::DeclarableOp*>(op);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("Bad opType passed in");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ops::DeclarableOp* nd4j::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) {
|
nd4j::ops::DeclarableOp* nd4j::graph::Node::buildOpByType(OpType opType, int numInputs, int numIArgs, int numTArgs, int opNum, NDArray *scalar) {
|
||||||
switch (opType) {
|
switch (opType) {
|
||||||
case OpType_PAIRWISE:
|
case OpType_PAIRWISE:
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -721,7 +721,7 @@ namespace shape {
|
||||||
INLINEDEF void TAD::createOffsets() {
|
INLINEDEF void TAD::createOffsets() {
|
||||||
this->tadOffsets = new Nd4jLong[this->numTads];
|
this->tadOffsets = new Nd4jLong[this->numTads];
|
||||||
uint nT = this->numTads;
|
uint nT = this->numTads;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for(uint i = 0; i < nT; i++)
|
for(uint i = 0; i < nT; i++)
|
||||||
this->tadOffsets[i] = this->tadOffset(i);
|
this->tadOffsets[i] = this->tadOffset(i);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include "../OpBenchmark.h"
|
#include "../OpBenchmark.h"
|
||||||
#include <helpers/BlasHelper.h>
|
|
||||||
#include <MmulHelper.h>
|
#include <MmulHelper.h>
|
||||||
|
|
||||||
#ifndef DEV_TESTS_MATRIXBENCHMARK_H
|
#ifndef DEV_TESTS_MATRIXBENCHMARK_H
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
#include <helpers/BlasHelper.h>
|
#include <helpers/BlasHelper.h>
|
||||||
#include <exceptions/datatype_exception.h>
|
#include <exceptions/datatype_exception.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -74,26 +75,28 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2))
|
auto func = PRAGMA_THREADS_FOR_2D { ;
|
||||||
for(uint row = 0; row < M; ++row) {
|
for (auto row = start_x; row < stop_x; row += inc_x) {
|
||||||
for(uint col = 0; col < N; ++col) {
|
for (auto col = start_y; col < stop_y; col += inc_y) {
|
||||||
|
T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col);
|
||||||
|
T3 val = 0;
|
||||||
|
|
||||||
T3* c = flagC ? (C + row + col * ldc) : (C + row * ldc + col);
|
PRAGMA_OMP_SIMD
|
||||||
T3 val = 0;
|
for (uint i = 0; i < K; ++i) {
|
||||||
|
T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda);
|
||||||
|
T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i);
|
||||||
|
val += alphaZ * a * b;
|
||||||
|
}
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
if (betaZ)
|
||||||
for(uint i = 0; i < K; ++i) {
|
*c = val + betaZ * *c;
|
||||||
T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda);
|
else
|
||||||
T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i);
|
*c = val;
|
||||||
val += alphaZ * a * b;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if(betaZ)
|
samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1);
|
||||||
*c = val + betaZ * *c;
|
|
||||||
else
|
|
||||||
*c = val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -108,24 +111,27 @@ static void usualGemv(const char aOrder, const int M, const int N, const double
|
||||||
|
|
||||||
const bool flagA = aOrder == 'f';
|
const bool flagA = aOrder == 'f';
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(int row = 0; row < M; ++row) {
|
for (auto row = start; row < stop; row += increment) {
|
||||||
|
|
||||||
T3* y = Y + row * incy;
|
T3 *y = Y + row * incy;
|
||||||
T3 val = 0;
|
T3 val = 0;
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for(int i = 0; i < N; ++i) {
|
for (int i = 0; i < N; ++i) {
|
||||||
T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i);
|
T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i);
|
||||||
T3 x = *(X + i * incx);
|
T3 x = *(X + i * incx);
|
||||||
val += alphaZ * a * x;
|
val += alphaZ * a * x;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (betaZ)
|
||||||
|
*y = val + betaZ * *y;
|
||||||
|
else
|
||||||
|
*y = val;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if(betaZ)
|
samediff::Threads::parallel_for(func, 0, M);
|
||||||
*y = val + betaZ * *y;
|
|
||||||
else
|
|
||||||
*y = val;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -141,7 +147,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX,
|
||||||
T3 sum = 0;
|
T3 sum = 0;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum))
|
||||||
for(int i = 0; i < length; ++i)
|
for(int i = 0; i < length; ++i)
|
||||||
sum = sum + X[i * incx] * Y[i * incy];
|
sum += X[i * incx] * Y[i * incy];
|
||||||
|
|
||||||
*Z = alphaZ * sum + betaZ * *Z;
|
*Z = alphaZ * sum + betaZ * *Z;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <TrueBroadcastHelper.h>
|
#include <TrueBroadcastHelper.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
|
|
@ -44,62 +44,67 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
const Nd4jLong* tadShape = shape::shapeOf(const_cast<Nd4jLong*>(tadShapeInfo));
|
const Nd4jLong* tadShape = shape::shapeOf(const_cast<Nd4jLong*>(tadShapeInfo));
|
||||||
const Nd4jLong* tadStride = shape::stride(const_cast<Nd4jLong*>(tadShapeInfo));
|
const Nd4jLong* tadStride = shape::stride(const_cast<Nd4jLong*>(tadShapeInfo));
|
||||||
|
|
||||||
int tadsPerThread = zLen / TAD_THRESHOLD;
|
|
||||||
int numThreads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
|
||||||
numThreads = nd4j::math::nd4j_min<int>(numThreads, omp_get_max_threads());
|
|
||||||
|
|
||||||
switch (kindOfLoop) {
|
switch (kindOfLoop) {
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case nd4j::LoopKind::EWS1: {
|
case nd4j::LoopKind::EWS1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; j++) {
|
for (uint j = 0; j < tadLen; j++) {
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[j], j);
|
functions::indexreduce::IndexValue<X> comp(tad[j], j);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case nd4j::LoopKind::EWSNONZERO: {
|
case nd4j::LoopKind::EWSNONZERO: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; j++) {
|
for (uint j = 0; j < tadLen; j++) {
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
|
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
z[i * zEws] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[i * zEws] = (Z) indexValue.index;
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
//*********************************************//
|
//*********************************************//
|
||||||
case nd4j::LoopKind::RANK1: {
|
case nd4j::LoopKind::RANK1: {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint i0 = 0; i0 < tadLen; ++i0) {
|
for (uint i0 = 0; i0 < tadLen; ++i0) {
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[i0 * tadStride[0]], i0);
|
functions::indexreduce::IndexValue<X> comp(tad[i0 * tadStride[0]], i0);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -108,22 +113,25 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
Nd4jLong newStride[2];
|
Nd4jLong newStride[2];
|
||||||
shape::updateStrides(2, tadShape, newStride, 'c');
|
shape::updateStrides(2, tadShape, newStride, 'c');
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
||||||
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
||||||
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1];
|
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1];
|
||||||
const auto tadIndex = i0 * newStride[0] + i1;
|
const auto tadIndex = i0 * newStride[0] + i1;
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -132,24 +140,27 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
Nd4jLong newStride[3];
|
Nd4jLong newStride[3];
|
||||||
shape::updateStrides(3, tadShape, newStride, 'c');
|
shape::updateStrides(3, tadShape, newStride, 'c');
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
||||||
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
||||||
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
||||||
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2];
|
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2];
|
||||||
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2;
|
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2;
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -158,26 +169,29 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
Nd4jLong newStride[4];
|
Nd4jLong newStride[4];
|
||||||
shape::updateStrides(4, tadShape, newStride, 'c');
|
shape::updateStrides(4, tadShape, newStride, 'c');
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
||||||
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
||||||
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
||||||
for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
|
for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
|
||||||
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3];
|
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3];
|
||||||
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3;
|
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3;
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -186,28 +200,31 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
Nd4jLong newStride[5];
|
Nd4jLong newStride[5];
|
||||||
shape::updateStrides(5, tadShape, newStride, 'c');
|
shape::updateStrides(5, tadShape, newStride, 'c');
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
for (uint i0 = 0; i0 < tadShape[0]; ++i0) {
|
||||||
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
for (uint i1 = 0; i1 < tadShape[1]; ++i1) {
|
||||||
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
for (uint i2 = 0; i2 < tadShape[2]; ++i2) {
|
||||||
for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
|
for (uint i3 = 0; i3 < tadShape[3]; ++i3) {
|
||||||
for (uint i4 = 0; i4 < tadShape[4]; ++i4) {
|
for (uint i4 = 0; i4 < tadShape[4]; ++i4) {
|
||||||
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4];
|
const auto tadOffset = i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4];
|
||||||
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4;
|
const auto tadIndex = i0 * newStride[0] + i1 * newStride[1] + i2 * newStride[2] + i3 * newStride[3] + i4;
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], tadIndex);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
z[i] = (Z) indexValue.index;
|
z[i] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -216,19 +233,22 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
uint castZShapeInfo[MAX_RANK];
|
uint castZShapeInfo[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; j++) {
|
for (uint j = 0; j < tadLen; j++) {
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
|
functions::indexreduce::IndexValue<X> comp(tad[j * tadEws], j);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
||||||
|
z[zOffset] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
z[zOffset] = (Z) indexValue.index;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -237,19 +257,22 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
uint castTadShapeInfo[MAX_RANK];
|
uint castTadShapeInfo[MAX_RANK];
|
||||||
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; j++) {
|
for (uint j = 0; j < tadLen; j++) {
|
||||||
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
|
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
z[i * zEws] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
z[i * zEws] = (Z) indexValue.index;
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -260,20 +283,23 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
|
||||||
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
const bool canCastTad = nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (uint i = 0; i < zLen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto tad = const_cast<X*>(x) + tadOffsets[i];
|
auto tad = const_cast<X *>(x) + tadOffsets[i];
|
||||||
auto indexValue = OpType::startingIndexValue(tad);
|
auto indexValue = OpType::startingIndexValue(tad);
|
||||||
|
|
||||||
for (uint j = 0; j < tadLen; j++) {
|
for (uint j = 0; j < tadLen; j++) {
|
||||||
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
|
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
|
||||||
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
|
functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
|
||||||
indexValue = OpType::update(indexValue, comp, extraParams);
|
indexValue = OpType::update(indexValue, comp, extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
||||||
|
z[zOffset] = (Z) indexValue.index;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
samediff::Threads::parallel_tad(func, 0, zLen);
|
||||||
z[zOffset] = (Z) indexValue.index;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,31 +28,31 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) {
|
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) {
|
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,31 +28,31 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) {
|
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) {
|
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,31 +28,31 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) {
|
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) {
|
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,31 +28,31 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, int* dims, int dimsLen, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams) {
|
void Reduction3Loops<X,Z>::innerloopReduce3All(X* x, Nd4jLong* xShapeInfo, X* y, Nd4jLong* yShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams);
|
Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams) {
|
void Reduction3Loops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, int* dims, int dimsLen, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dims, dimsLen, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams) {
|
void Reduction3Loops<X, Y>::wrapperAll(const int opNum, X *x, Nd4jLong *xShapeInfo, X *y, Nd4jLong *yShapeInfo, Y *z, Nd4jLong *zShapeInfo, Nd4jLong* xTadShapeInfo, Nd4jLong* xTadOffsets, Nd4jLong* yTadShapeInfo, Nd4jLong* yTadOffsets, Y* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce3All, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xTadOffsets, yTadShapeInfo, yTadOffsets, extraParams, start, stop), REDUCE3_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,9 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) {
|
void ReductionBoolLoops<X, Z>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,9 +36,9 @@ namespace nd4j {
|
||||||
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets,
|
Nd4jLong *tadOffsets,
|
||||||
X *extraParams) {
|
X *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), REDUCE_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,18 +28,18 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets, Y *extraParams) {
|
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,18 +28,18 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets, Y *extraParams) {
|
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,18 +28,18 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets, Y *extraParams) {
|
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,18 +28,18 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams) {
|
void ReductionFloatLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets, Y *extraParams) {
|
Nd4jLong *tadOffsets, Y *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,18 +33,18 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) {
|
void ReductionLongLoops<X, Z>::innerloopReduce(X * x, Nd4jLong* xShapeInfo, Z *z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
void ReductionLongLoops<X, Y>::wrapper(const int opNum, X *x, Nd4jLong *xShapeInfo, Y *z,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets, X *extraParams) {
|
Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams ), REDUCE_LONG_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,9 @@ namespace nd4j {
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) {
|
void ReductionSameLoops<X>::innerloopReduce(X* x, Nd4jLong* xShapeInfo, X* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams);
|
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,13 +36,13 @@ namespace nd4j {
|
||||||
void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz,
|
void ReductionSameLoops<X>::wrapper(const int opNum, X *vx, Nd4jLong *xShapeInfo, X *vz,
|
||||||
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffsets,
|
Nd4jLong *tadOffsets,
|
||||||
X *vextraParams) {
|
X *vextraParams, int64_t start, int64_t stop) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), REDUCE_SAME_OPS);
|
DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
#include <specials.h>
|
#include <specials.h>
|
||||||
#include <logger.h>
|
#include <logger.h>
|
||||||
|
#include <ops/ops.h>
|
||||||
// #include <cuda_runtime.h>
|
// #include <cuda_runtime.h>
|
||||||
// #include <cuda.h>
|
// #include <cuda.h>
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ namespace nd4j {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMV<float>() {
|
bool BlasHelper::hasGEMV<float>() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasSgemv;
|
return _hasSgemv;
|
||||||
|
@ -83,7 +83,7 @@ namespace nd4j {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMV<double>() {
|
bool BlasHelper::hasGEMV<double>() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasDgemv;
|
return _hasDgemv;
|
||||||
|
@ -132,14 +132,14 @@ namespace nd4j {
|
||||||
|
|
||||||
bool BlasHelper::hasGEMV(const nd4j::DataType dtype) {
|
bool BlasHelper::hasGEMV(const nd4j::DataType dtype) {
|
||||||
if(dtype == DataType::FLOAT32) {
|
if(dtype == DataType::FLOAT32) {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasSgemv;
|
return _hasSgemv;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
if(dtype == DataType::DOUBLE) {
|
if(dtype == DataType::DOUBLE) {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasDgemv;
|
return _hasDgemv;
|
||||||
|
@ -150,7 +150,7 @@ namespace nd4j {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMM<float>() {
|
bool BlasHelper::hasGEMM<float>() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasSgemm;
|
return _hasSgemm;
|
||||||
|
@ -159,7 +159,7 @@ namespace nd4j {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMM<double>() {
|
bool BlasHelper::hasGEMM<double>() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasDgemm;
|
return _hasDgemm;
|
||||||
|
@ -208,14 +208,14 @@ namespace nd4j {
|
||||||
|
|
||||||
bool BlasHelper:: hasGEMM(const nd4j::DataType dtype) {
|
bool BlasHelper:: hasGEMM(const nd4j::DataType dtype) {
|
||||||
if(dtype == DataType::FLOAT32) {
|
if(dtype == DataType::FLOAT32) {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasSgemm;
|
return _hasSgemm;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
if(dtype == DataType::DOUBLE) {
|
if(dtype == DataType::DOUBLE) {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
return _hasDgemm;
|
return _hasDgemm;
|
||||||
|
@ -276,14 +276,14 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
CblasSgemv BlasHelper::sgemv() {
|
CblasSgemv BlasHelper::sgemv() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__)|| defined(HAVE_OPENBLAS)
|
||||||
return (CblasSgemv)&cblas_sgemv;
|
return (CblasSgemv)&cblas_sgemv;
|
||||||
#else
|
#else
|
||||||
return this->cblasSgemv;
|
return this->cblasSgemv;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CblasDgemv BlasHelper::dgemv() {
|
CblasDgemv BlasHelper::dgemv() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return (CblasDgemv)&cblas_dgemv;
|
return (CblasDgemv)&cblas_dgemv;
|
||||||
#else
|
#else
|
||||||
return this->cblasDgemv;
|
return this->cblasDgemv;
|
||||||
|
@ -291,7 +291,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
CblasSgemm BlasHelper::sgemm() {
|
CblasSgemm BlasHelper::sgemm() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return (CblasSgemm)&cblas_sgemm;
|
return (CblasSgemm)&cblas_sgemm;
|
||||||
#else
|
#else
|
||||||
return this->cblasSgemm;
|
return this->cblasSgemm;
|
||||||
|
@ -299,7 +299,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
CblasDgemm BlasHelper::dgemm() {
|
CblasDgemm BlasHelper::dgemm() {
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_MKLDNN) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return (CblasDgemm)&cblas_dgemm;
|
return (CblasDgemm)&cblas_dgemm;
|
||||||
#else
|
#else
|
||||||
return this->cblasDgemm;
|
return this->cblasDgemm;
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
#include <ops/declarable/headers/parity_ops.h>
|
#include <ops/declarable/headers/parity_ops.h>
|
||||||
#include <helpers/DebugInfo.h>
|
#include <helpers/DebugInfo.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
DebugInfo DebugHelper::debugStatistics(NDArray const* input) {
|
DebugInfo DebugHelper::debugStatistics(NDArray const* input) {
|
||||||
|
@ -88,11 +89,18 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) reduction(+:_nanCount,_infCount,_m
|
||||||
}
|
}
|
||||||
*info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount};
|
*info = {_minValue, _maxValue, _meanValue / input->lengthOf(), _stdDevValue, _zeroCount, _positiveCount, _negativeCount, _infCount, _nanCount};
|
||||||
_stdDevValue = 0; //math::nd4j_sqrt<double, double>(info->_stdDevValue / (input->lengthOf() - 1));
|
_stdDevValue = 0; //math::nd4j_sqrt<double, double>(info->_stdDevValue / (input->lengthOf() - 1));
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule (static) reduction(+:_stdDevValue))
|
|
||||||
for (Nd4jLong e = 0; e < input->lengthOf(); e++) {
|
auto func = PRAGMA_REDUCE_DOUBLE {
|
||||||
double current = input->e<double>(e);
|
auto _stdDevValue = 0.0;
|
||||||
_stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue;
|
for (auto e = start; e < stop; e++) {
|
||||||
}
|
double current = input->e<double>(e);
|
||||||
|
_stdDevValue += (info->_meanValue - current) * (info->_meanValue - current); //info->_minValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return _stdDevValue;
|
||||||
|
};
|
||||||
|
_stdDevValue = samediff::Threads::parallel_double(func, LAMBDA_AD { return _old + _new; }, 0, input->lengthOf());
|
||||||
|
|
||||||
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());
|
info->_stdDevValue = math::nd4j_sqrt<double, double>(_stdDevValue / input->lengthOf());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,13 +33,11 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
|
||||||
switch(loss) {
|
switch(loss) {
|
||||||
|
|
||||||
case MEAN:
|
case MEAN:
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
|
|
||||||
for(int i = 0; i < numInGradArrs; ++i)
|
for(int i = 0; i < numInGradArrs; ++i)
|
||||||
*gradArrs[i] = 1. / gradArrs[i]->lengthOf();
|
*gradArrs[i] = 1. / gradArrs[i]->lengthOf();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case SUM:
|
case SUM:
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
|
|
||||||
for(int i = 0; i < numInGradArrs; ++i)
|
for(int i = 0; i < numInGradArrs; ++i)
|
||||||
*gradArrs[i] = 1.;
|
*gradArrs[i] = 1.;
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -45,7 +45,7 @@ OmpLaunchHelper::OmpLaunchHelper(const Nd4jLong N, float desiredNumThreads) {
|
||||||
else
|
else
|
||||||
desiredNumThreads = nd4j::math::nd4j_min<int>(omp_get_max_threads(), desiredNumThreads);
|
desiredNumThreads = nd4j::math::nd4j_min<int>(omp_get_max_threads(), desiredNumThreads);
|
||||||
#else
|
#else
|
||||||
desiredNumThreads = 1;
|
desiredNumThreads = nd4j::Environment::getInstance()->maxThreads();
|
||||||
#endif
|
#endif
|
||||||
_numThreads = nd4j::math::nd4j_min<int>(N / maxItersPerThread, desiredNumThreads);
|
_numThreads = nd4j::math::nd4j_min<int>(N / maxItersPerThread, desiredNumThreads);
|
||||||
}
|
}
|
||||||
|
@ -75,7 +75,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
return betterThreads(N, omp_get_max_threads());
|
return betterThreads(N, omp_get_max_threads());
|
||||||
#else
|
#else
|
||||||
return 1;
|
return betterThreads(N, nd4j::Environment::getInstance()->maxThreads());;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ Nd4jLong OmpLaunchHelper::betterSpan(Nd4jLong N) {
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
auto maxThreads = omp_get_max_threads();
|
auto maxThreads = omp_get_max_threads();
|
||||||
#else
|
#else
|
||||||
auto maxThreads = 1;
|
auto maxThreads = nd4j::Environment::getInstance()->maxThreads();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// if there's only 1 thread allowed - nothing to do here
|
// if there's only 1 thread allowed - nothing to do here
|
||||||
|
|
|
@ -1,66 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author raver119@gmail.com
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifndef LIBND4J_AGGREGATES_H
|
|
||||||
#define LIBND4J_AGGREGATES_H
|
|
||||||
|
|
||||||
#include <ops/aggregate_ops.h>
|
|
||||||
#include <helpers/DebugHelper.h>
|
|
||||||
#include <helpers/helper_ptrmap.h>
|
|
||||||
|
|
||||||
namespace functions {
|
|
||||||
namespace aggregate {
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
class AggregatedFunction {
|
|
||||||
|
|
||||||
public:
|
|
||||||
#ifdef __CUDACC__
|
|
||||||
template<typename OpClass>
|
|
||||||
__device__ static void execCuda(X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments);
|
|
||||||
|
|
||||||
__device__ static void execCuda(int opNum, X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments);
|
|
||||||
|
|
||||||
__device__ static void aggregateBatch(int numAggregates, int opNum, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments);
|
|
||||||
|
|
||||||
__host__ static void aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments);
|
|
||||||
|
|
||||||
__host__ static void aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream, int opNum, void **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, void *realArguments, int numRealArguments);
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template<typename OpClass>
|
|
||||||
inline static void exec(X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
|
|
||||||
OpClass::executeAggregate(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static void exec(int opNum, X **arguments, int numArguments, Nd4jLong **shapeArguments, int numShapeArguments, int *indexArguments, int numIndexArguments, int **intArrays, int numIntArrays, X *realArguments, int numRealArguments) {
|
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), AGGREGATE_OPS);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef __CUDACC__
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#endif //LIBND4J_AGGREGATES_H
|
|
|
@ -91,7 +91,7 @@ namespace functions {
|
||||||
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#else
|
||||||
|
|
||||||
static void execInverse(int opNum,
|
static void execInverse(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -105,7 +105,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
static void exec(int opNum,
|
static void exec(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -119,7 +121,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* CPU execution
|
* CPU execution
|
||||||
|
@ -144,7 +148,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
static void execInverse(void *x,
|
static void execInverse(void *x,
|
||||||
|
@ -158,7 +164,10 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
||||||
|
|
||||||
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
#endif
|
#else
|
||||||
|
|
||||||
static void exec(int opNum,
|
static void exec(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -103,7 +103,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
static void execInverse(int opNum,
|
static void execInverse(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -117,7 +119,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* CPU execution
|
* CPU execution
|
||||||
|
@ -142,7 +146,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
static void execInverse(void *x,
|
static void execInverse(void *x,
|
||||||
|
@ -156,7 +162,10 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
||||||
|
|
||||||
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
#endif
|
#else
|
||||||
|
|
||||||
static void exec(int opNum,
|
static void exec(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -103,7 +103,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
static void execInverse(int opNum,
|
static void execInverse(int opNum,
|
||||||
void *x,
|
void *x,
|
||||||
|
@ -117,7 +119,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* CPU execution
|
* CPU execution
|
||||||
|
@ -142,7 +146,9 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
static void execInverse(void *x,
|
static void execInverse(void *x,
|
||||||
|
@ -156,7 +162,10 @@ namespace functions {
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset,
|
Nd4jLong *tadOffset,
|
||||||
Nd4jLong *tadShapeInfoZ,
|
Nd4jLong *tadShapeInfoZ,
|
||||||
Nd4jLong *tadOffsetZ);
|
Nd4jLong *tadOffsetZ,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop);
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -43,7 +44,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x,
|
DISPATCH_BY_OPNUM_TTT(execInverse, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -55,7 +58,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_OPS);
|
zTadOffset, start, stop), BROADCAST_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
@ -71,7 +74,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -83,7 +88,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_OPS);
|
zTadOffset, start, stop), BROADCAST_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
@ -99,7 +104,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<Y *>(vy);
|
auto y = reinterpret_cast<Y *>(vy);
|
||||||
|
@ -131,10 +138,6 @@ namespace functions {
|
||||||
auto lenZ = shape::length(zTadShapeInfo);
|
auto lenZ = shape::length(zTadShapeInfo);
|
||||||
auto lenY = shape::length(yShapeInfo);
|
auto lenY = shape::length(yShapeInfo);
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
|
||||||
|
|
||||||
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
auto zEws = shape::elementWiseStride(zTadShapeInfo);
|
auto zEws = shape::elementWiseStride(zTadShapeInfo);
|
||||||
|
@ -142,19 +145,17 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
auto oX = x + tadOffsets[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], y[f]);
|
oZ[f] = OpType::op(oX[f], y[f]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO){
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO){
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -164,13 +165,10 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -182,70 +180,61 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i++) {
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
|
@ -253,17 +242,15 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -285,7 +272,9 @@ namespace functions {
|
||||||
Nd4jLong *yTadShapeInfo,
|
Nd4jLong *yTadShapeInfo,
|
||||||
Nd4jLong *yTadOffset,
|
Nd4jLong *yTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<Y *>(vy);
|
auto y = reinterpret_cast<Y *>(vy);
|
||||||
|
@ -319,7 +308,7 @@ namespace functions {
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
@ -328,8 +317,7 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if(kindOfLoop == nd4j::LoopKind::EWS1) {
|
if(kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (unsigned int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -339,24 +327,20 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
|
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oY = x + tadOffsets[i];
|
auto oY = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -365,73 +349,63 @@ namespace functions {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[offset]);
|
oZ[offset] = OpType::op(x[offset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
|
@ -439,20 +413,18 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
|
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -43,7 +44,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -55,7 +58,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_BOOL_OPS);
|
zTadOffset, start, stop), BROADCAST_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
|
@ -71,7 +74,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x,
|
DISPATCH_BY_OPNUM_TT(execInverse, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -83,7 +88,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_BOOL_OPS);
|
zTadOffset, start, stop), BROADCAST_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -99,7 +104,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -133,7 +140,7 @@ namespace functions {
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
|
@ -142,10 +149,9 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
|
@ -153,101 +159,86 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
// TODO: cover this codebranch with tests
|
|
||||||
// all this stuff already happens within thread
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[offset]);
|
oZ[offset] = OpType::op(oX[offset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
|
@ -255,20 +246,18 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,7 +275,9 @@ namespace functions {
|
||||||
Nd4jLong *yTadShapeInfo,
|
Nd4jLong *yTadShapeInfo,
|
||||||
Nd4jLong *yTadOffset,
|
Nd4jLong *yTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -320,7 +311,7 @@ namespace functions {
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
@ -329,8 +320,7 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -340,8 +330,7 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
|
@ -355,14 +344,10 @@ namespace functions {
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
// TODO: cover this codebranch with tests
|
|
||||||
// all this stuff already happens within thread
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
|
@ -377,15 +362,13 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
||||||
}
|
}
|
||||||
|
@ -398,15 +381,13 @@ namespace functions {
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
||||||
}
|
}
|
||||||
|
@ -419,16 +400,14 @@ namespace functions {
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -442,9 +421,7 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -43,7 +44,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -55,7 +58,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_INT_OPS);
|
zTadOffset, start, stop), BROADCAST_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -71,7 +74,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
|
DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
|
@ -83,7 +88,7 @@ namespace functions {
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
xTadOffset,
|
xTadOffset,
|
||||||
zTadShapeInfo,
|
zTadShapeInfo,
|
||||||
zTadOffset), BROADCAST_INT_OPS);
|
zTadOffset, start, stop), BROADCAST_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -99,7 +104,9 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffset,
|
Nd4jLong *xTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -133,7 +140,7 @@ namespace functions {
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
|
@ -142,112 +149,95 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], y[f]);
|
oZ[f] = OpType::op(oX[f], y[f]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
// TODO: cover this codebranch with tests
|
|
||||||
// all this stuff already happens within thread
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[offset]);
|
oZ[offset] = OpType::op(oX[offset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
|
@ -255,20 +245,18 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oX = x + tadOffsets[i];
|
auto oX = x + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,7 +274,9 @@ namespace functions {
|
||||||
Nd4jLong *yTadShapeInfo,
|
Nd4jLong *yTadShapeInfo,
|
||||||
Nd4jLong *yTadOffset,
|
Nd4jLong *yTadOffset,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffset) {
|
Nd4jLong *zTadOffset,
|
||||||
|
uint64_t start,
|
||||||
|
uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -320,7 +310,7 @@ namespace functions {
|
||||||
|
|
||||||
int tadsPerThread = tads / TAD_THRESHOLD;
|
int tadsPerThread = tads / TAD_THRESHOLD;
|
||||||
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = nd4j::math::nd4j_min<int>(threads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
@ -329,46 +319,39 @@ namespace functions {
|
||||||
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(x[f], oY[f]);
|
oZ[f] = OpType::op(x[f], oY[f]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint f = 0; f < tadLength; f++)
|
for (uint f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
|
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
|
|
||||||
// TODO: cover this codebranch with tests
|
|
||||||
// all this stuff already happens within thread
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[offset]);
|
oZ[offset] = OpType::op(x[offset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
|
||||||
|
|
||||||
|
@ -377,64 +360,54 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
|
||||||
|
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
uint tadShapeInfoZCast[MAX_RANK];
|
uint tadShapeInfoZCast[MAX_RANK];
|
||||||
|
@ -442,9 +415,7 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
|
for (auto i = start; i < stop; i ++) {
|
||||||
for (int i = 0; i < tads; i++) {
|
|
||||||
|
|
||||||
auto oZ = z + zTadOffset[i];
|
auto oZ = z + zTadOffset[i];
|
||||||
auto oY = y + tadOffsets[i];
|
auto oY = y + tadOffsets[i];
|
||||||
|
|
||||||
|
@ -455,7 +426,7 @@ namespace functions {
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
|
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include <Loops.h>
|
#include <Loops.h>
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
#include "../legacy_ops.h"
|
#include "../legacy_ops.h"
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
@ -44,8 +45,7 @@ void IndexReduce<X,Y>::exec(const int opNum,
|
||||||
void *z, Nd4jLong *zShapeInfo,
|
void *z, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
||||||
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -64,42 +64,41 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
IndexValue<X> intermediatery[64];
|
||||||
|
for (int e = 0; e < maxThreads; e++)
|
||||||
|
intermediatery[e].index = -1;
|
||||||
|
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||||
auto local = OpType::startingIndexValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
|
|
||||||
auto ulen = info.getItersPerThread(threadNum);
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
IndexValue<X> curr(x[i], i);
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||||
IndexValue<X> curr(x[i + threadOffset], i + threadOffset);
|
|
||||||
local = OpType::update(local, curr, extraParams);
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
||||||
|
|
||||||
|
for (int e = 0; e < maxThreads; e++)
|
||||||
|
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
|
||||||
startingIndex = OpType::update(startingIndex, local, extraParams);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||||
auto local = OpType::startingIndexValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
|
|
||||||
auto ulen = info.getItersPerThread(threadNum);
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
IndexValue<X> curr(x[offset], i);
|
||||||
auto offset = shape::indexOffset(threadOffset + i, xShapeInfo, xShapeInfoCast, canCastX);
|
intermediatery[thread_id] = OpType::update(intermediatery[thread_id], curr, extraParams);
|
||||||
IndexValue<X> curr(x[offset], threadOffset + i);
|
|
||||||
local = OpType::update(local, curr, extraParams);
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
maxThreads = samediff::Threads::parallel_for(func, 0, len, 1, maxThreads);
|
||||||
startingIndex = OpType::update(startingIndex, local, extraParams);
|
|
||||||
}
|
for (int e = 0; e < maxThreads; e++)
|
||||||
|
startingIndex = OpType::update(startingIndex, intermediatery[e], extraParams);
|
||||||
}
|
}
|
||||||
return startingIndex.index;
|
return startingIndex.index;
|
||||||
}
|
}
|
||||||
|
@ -124,9 +123,10 @@ void IndexReduce<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto indexValue = OpType::startingIndexValue(x);
|
const auto indexValue = OpType::startingIndexValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < zLen; i++)
|
for (uint i = 0; i < zLen; i++)
|
||||||
z[i] = (Z) indexValue.index;;
|
z[i] = (Z) indexValue.index;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include <helpers/shape.h>
|
#include <helpers/shape.h>
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <OmpLaunchHelper.h>
|
#include <OmpLaunchHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -42,7 +43,9 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong n) {
|
Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
||||||
xEws,
|
xEws,
|
||||||
y,
|
y,
|
||||||
|
@ -50,7 +53,7 @@ namespace functions {
|
||||||
z,
|
z,
|
||||||
zEws,
|
zEws,
|
||||||
extraParams,
|
extraParams,
|
||||||
n), PAIRWISE_TRANSFORM_OPS);
|
n, start, stop), PAIRWISE_TRANSFORM_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,48 +64,24 @@ namespace functions {
|
||||||
void *vy, Nd4jLong yEws,
|
void *vy, Nd4jLong yEws,
|
||||||
void *vz, Nd4jLong zEws,
|
void *vz, Nd4jLong zEws,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong n) {
|
const Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<Y *>(vy);
|
auto y = reinterpret_cast<Y *>(vy);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], y[i], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto yi = y + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], yi[i], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto yi = y + yEws*threadOffset;
|
|
||||||
auto zi = z + zEws*threadOffset;
|
|
||||||
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,14 +94,16 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
yShapeInfo,
|
yShapeInfo,
|
||||||
z,
|
z,
|
||||||
zShapeInfo,
|
zShapeInfo,
|
||||||
extraParams),
|
extraParams, start, stop),
|
||||||
PAIRWISE_TRANSFORM_OPS);
|
PAIRWISE_TRANSFORM_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -136,7 +117,9 @@ namespace functions {
|
||||||
Nd4jLong* yShapeInfo,
|
Nd4jLong* yShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong* zShapeInfo,
|
Nd4jLong* zShapeInfo,
|
||||||
void *vextraParams) {
|
void *vextraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<Y *>(vy);
|
auto y = reinterpret_cast<Y *>(vy);
|
||||||
|
@ -148,7 +131,6 @@ namespace functions {
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (shape::isScalar(yShapeInfo)) {
|
if (shape::isScalar(yShapeInfo)) {
|
||||||
|
|
||||||
|
@ -156,38 +138,22 @@ namespace functions {
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for(auto i = start; i < stop; i++) {
|
||||||
{
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadNum = omp_get_thread_num();
|
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
};
|
||||||
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for(auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -198,96 +164,63 @@ namespace functions {
|
||||||
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
||||||
|
|
||||||
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
|
||||||
}
|
}
|
||||||
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
|
@ -295,20 +228,13 @@ namespace functions {
|
||||||
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
||||||
PRAGMA_OMP_SIMD
|
};
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,106 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// Created by remote on 2018-09-20.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <ops/ops.h>
|
|
||||||
#include <loops/pairwise_transform.h>
|
|
||||||
#include <types/types.h>
|
|
||||||
#include <templatemath.h>
|
|
||||||
#include <helpers/shape.h>
|
|
||||||
#include <op_boilerplate.h>
|
|
||||||
#include <OmpLaunchHelper.h>
|
|
||||||
|
|
||||||
using namespace simdOps;
|
|
||||||
|
|
||||||
namespace functions {
|
|
||||||
namespace pairwise_transforms {
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void PairWiseTransform<X, Y, Z>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong xEws,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong yEws,
|
|
||||||
void *z,
|
|
||||||
Nd4jLong zEws,
|
|
||||||
void *extraParams,
|
|
||||||
Nd4jLong n) {
|
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
|
|
||||||
xEws,
|
|
||||||
y,
|
|
||||||
yEws,
|
|
||||||
z,
|
|
||||||
zEws,
|
|
||||||
extraParams,
|
|
||||||
n), PAIRWISE_TRANSFORM_OPS);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void PairWiseTransform<X, Y, Z>::exec(void *vx, Nd4jLong xEws,
|
|
||||||
void *vy, Nd4jLong yEws,
|
|
||||||
void *vz, Nd4jLong zEws,
|
|
||||||
void *vextraParams,
|
|
||||||
const Nd4jLong n) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
|
||||||
auto y = reinterpret_cast<Y *>(vy);
|
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto yi = y + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], yi[i], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto yi = y + yEws*threadOffset;
|
|
||||||
auto zi = z + zEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <OmpLaunchHelper.h>
|
#include <OmpLaunchHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -38,7 +39,9 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong n) {
|
Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
||||||
xEws,
|
xEws,
|
||||||
y,
|
y,
|
||||||
|
@ -46,7 +49,7 @@ namespace functions {
|
||||||
z,
|
z,
|
||||||
zEws,
|
zEws,
|
||||||
extraParams,
|
extraParams,
|
||||||
n), PAIRWISE_BOOL_OPS);
|
n, start, stop), PAIRWISE_BOOL_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,46 +63,24 @@ namespace functions {
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong n) {
|
const Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], y[i], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto yi = y + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], yi[i], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto yi = y + yEws*threadOffset;
|
|
||||||
auto zi = z + zEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,14 +93,16 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
yShapeInfo,
|
yShapeInfo,
|
||||||
z,
|
z,
|
||||||
zShapeInfo,
|
zShapeInfo,
|
||||||
extraParams),
|
extraParams, start, stop),
|
||||||
PAIRWISE_BOOL_OPS);
|
PAIRWISE_BOOL_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -129,7 +112,9 @@ namespace functions {
|
||||||
void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo,
|
void PairWiseBoolTransform<X, Z>::exec(void *vx, Nd4jLong* xShapeInfo,
|
||||||
void *vy, Nd4jLong* yShapeInfo,
|
void *vy, Nd4jLong* yShapeInfo,
|
||||||
void *vz, Nd4jLong* zShapeInfo,
|
void *vz, Nd4jLong* zShapeInfo,
|
||||||
void *vextraParams) {
|
void *vextraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -141,8 +126,6 @@ namespace functions {
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (shape::isScalar(yShapeInfo)) {
|
if (shape::isScalar(yShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
|
@ -150,37 +133,22 @@ namespace functions {
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for(auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for(auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -189,96 +157,62 @@ namespace functions {
|
||||||
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
||||||
|
|
||||||
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
|
||||||
}
|
}
|
||||||
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
|
@ -286,20 +220,13 @@ namespace functions {
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
||||||
PRAGMA_OMP_SIMD
|
};
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
#include <OmpLaunchHelper.h>
|
#include <OmpLaunchHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -38,7 +39,9 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong n) {
|
Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
||||||
xEws,
|
xEws,
|
||||||
y,
|
y,
|
||||||
|
@ -46,7 +49,7 @@ namespace functions {
|
||||||
z,
|
z,
|
||||||
zEws,
|
zEws,
|
||||||
extraParams,
|
extraParams,
|
||||||
n), PAIRWISE_INT_OPS);
|
n, start, stop), PAIRWISE_INT_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,46 +63,24 @@ namespace functions {
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong n) {
|
const Nd4jLong n,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
if (xEws == 1 && yEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], y[i], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto yi = y + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], yi[i], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i*zEws] = OpType::op(x[i*xEws], y[i*yEws], extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
Nd4jLong threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto yi = y + yEws*threadOffset;
|
|
||||||
auto zi = z + zEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,14 +93,16 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
y,
|
y,
|
||||||
yShapeInfo,
|
yShapeInfo,
|
||||||
z,
|
z,
|
||||||
zShapeInfo,
|
zShapeInfo,
|
||||||
extraParams),
|
extraParams, start, stop),
|
||||||
PAIRWISE_INT_OPS);
|
PAIRWISE_INT_OPS);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -129,7 +112,9 @@ namespace functions {
|
||||||
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
|
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
|
||||||
void *vy, Nd4jLong* yShapeInfo,
|
void *vy, Nd4jLong* yShapeInfo,
|
||||||
void *vz, Nd4jLong* zShapeInfo,
|
void *vz, Nd4jLong* zShapeInfo,
|
||||||
void *vextraParams) {
|
void *vextraParams,
|
||||||
|
const uint64_t start,
|
||||||
|
const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -141,46 +126,28 @@ namespace functions {
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(n);
|
|
||||||
|
|
||||||
if (shape::isScalar(yShapeInfo)) {
|
if (shape::isScalar(yShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for(auto i = start; i < stop; i++) {
|
||||||
{
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadNum = omp_get_thread_num();
|
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
};
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for(auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for(Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -189,96 +156,63 @@ namespace functions {
|
||||||
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);
|
||||||
|
|
||||||
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n, start, stop);
|
||||||
}
|
}
|
||||||
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
|
||||||
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
|
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo), start, stop);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
|
@ -286,20 +220,13 @@ namespace functions {
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
||||||
PRAGMA_OMP_SIMD
|
};
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,28 +52,22 @@ namespace functions {
|
||||||
|
|
||||||
auto length = shape::length(zShapeInfo);
|
auto length = shape::length(zShapeInfo);
|
||||||
|
|
||||||
// nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
|
|
||||||
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
||||||
nd4j::OmpLaunchHelper info(length);
|
|
||||||
|
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
|
@ -82,19 +76,16 @@ namespace functions {
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
|
@ -103,19 +94,16 @@ namespace functions {
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
|
@ -124,19 +112,16 @@ namespace functions {
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < info.getItersPerThread(threadNum); i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -147,20 +132,17 @@ namespace functions {
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments);
|
z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -184,41 +166,34 @@ namespace functions {
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
||||||
nd4j::OmpLaunchHelper info(length);
|
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
|
z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
|
z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,25 +207,21 @@ namespace functions {
|
||||||
|
|
||||||
auto length = shape::length(zShapeInfo);
|
auto length = shape::length(zShapeInfo);
|
||||||
|
|
||||||
//nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
|
|
||||||
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
|
||||||
nd4j::OmpLaunchHelper info(length);
|
nd4j::OmpLaunchHelper info(length);
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
{
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
for (uint64_t i = start; i < stop; i += increment) {
|
||||||
auto offset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments);
|
z[offset] = OpClass::op(i, length, rng, extraArguments);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, length, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
|
|
|
@ -55,7 +55,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < length; i++)
|
for (uint i = 0; i < length; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -65,25 +65,14 @@ namespace functions {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
|
||||||
X intermediate[256];
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
|
for (auto i = 0; i < length; i++)
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
|
||||||
|
|
||||||
|
z[0] = OpType::postProcess(startingValue, length, extraParams);
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,23 +91,14 @@ namespace functions {
|
||||||
return execScalar<OpType>(x, xEws, length, extraParams);
|
return execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
for (auto i = 0; i < length; i++)
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
|
||||||
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
return OpType::postProcess(startingValue, length, extraParams);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
delete[] intermediate;
|
|
||||||
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,8 +130,8 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), REDUCE_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -164,7 +144,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
auto z = reinterpret_cast<Z *>(vresult);
|
||||||
|
@ -176,7 +156,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < resultLength; i++)
|
for (uint i = 0; i < resultLength; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -205,9 +185,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,49 +207,33 @@ namespace functions {
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
Z _CUDA_H ReduceBoolFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
Z intermediate[64];
|
||||||
|
|
||||||
auto startingVal = OpType::startingValue(x);
|
PRAGMA_OMP_SIMD
|
||||||
nd4j::OmpLaunchHelper info(length);
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
if (xEws == 1) {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
if (xEws == 1) {
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
|
||||||
auto local = OpType::startingValue(x);
|
} else {
|
||||||
auto threadNum = omp_get_thread_num();
|
for (auto i = start; i < stop; i++)
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++) {
|
|
||||||
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
|
||||||
}
|
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
{
|
|
||||||
auto local = OpType::startingValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
// merge results
|
||||||
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
// return result
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
|
||||||
}
|
|
||||||
return OpType::postProcess(startingVal, length, extraParams);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -59,9 +59,10 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < length; i++)
|
for (uint i = 0; i < length; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,25 +70,29 @@ namespace functions {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
|
||||||
X intermediate[256];
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
|
PRAGMA_OMP_SIMD
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
|
};
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
// merge results
|
||||||
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
|
// write out results
|
||||||
|
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,23 +110,14 @@ namespace functions {
|
||||||
return execScalar<OpType>(x, xEws, length, extraParams);
|
return execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
for (auto i = 0; i < length; i++)
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
|
||||||
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
return OpType::postProcess(startingValue, length, extraParams);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
delete[] intermediate;
|
|
||||||
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +149,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
extraParams,
|
extraParams,
|
||||||
|
@ -162,7 +158,7 @@ namespace functions {
|
||||||
dimension,
|
dimension,
|
||||||
dimensionLength,
|
dimensionLength,
|
||||||
tadShapeInfo,
|
tadShapeInfo,
|
||||||
tadOffset),
|
tadOffset, start, stop),
|
||||||
REDUCE_FLOAT_OPS);
|
REDUCE_FLOAT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +172,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
auto z = reinterpret_cast<Z *>(vresult);
|
||||||
|
@ -188,7 +184,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < resultLength; i++)
|
for (uint i = 0; i < resultLength; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -222,9 +218,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,49 +241,34 @@ namespace functions {
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
Z _CUDA_H ReduceFloatFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
Z intermediate[64];
|
||||||
|
|
||||||
auto startingVal = OpType::startingValue(x);
|
PRAGMA_OMP_SIMD
|
||||||
nd4j::OmpLaunchHelper info(length);
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
int nt = info._numThreads;
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
if (xEws == 1) {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
if (xEws == 1) {
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
|
||||||
auto local = OpType::startingValue(x);
|
} else {
|
||||||
auto threadNum = omp_get_thread_num();
|
for (auto i = start; i < stop; i++)
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
{
|
|
||||||
auto local = OpType::startingValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
// merge results
|
||||||
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
// return result
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return OpType::postProcess(startingVal, length, extraParams);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceFloatFunction, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
|
@ -55,7 +55,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < length; i++)
|
for (uint i = 0; i < length; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -65,25 +65,29 @@ namespace functions {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
|
||||||
X intermediate[256];
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
Z intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
|
PRAGMA_OMP_SIMD
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
|
};
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
// merge results
|
||||||
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
|
// write out results
|
||||||
|
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,23 +107,14 @@ namespace functions {
|
||||||
return execScalar<OpType>(x, xEws, length, extraParams);
|
return execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
auto intermediate = new X[nd4j::math::nd4j_max<int>(1, omp_get_max_threads())];
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
for (auto i = 0; i < length; i++)
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
|
||||||
|
|
||||||
for (int e = 0; e < omp_get_max_threads(); e++)
|
return OpType::postProcess(startingValue, length, extraParams);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
delete[] intermediate;
|
|
||||||
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,8 +147,8 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), REDUCE_LONG_OPS);
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_LONG_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -166,7 +161,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
auto z = reinterpret_cast<Z *>(vresult);
|
||||||
|
@ -178,7 +173,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < resultLength; i++)
|
for (uint i = 0; i < resultLength; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -212,9 +207,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::ReductionLongLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionLongLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,48 +230,34 @@ namespace functions {
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
Z intermediate[64];
|
||||||
|
|
||||||
auto startingVal = OpType::startingValue(x);
|
PRAGMA_OMP_SIMD
|
||||||
nd4j::OmpLaunchHelper info(length);
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
|
||||||
{
|
} else {
|
||||||
auto local = OpType::startingValue(x);
|
for (auto i = start; i < stop; i++)
|
||||||
auto threadNum = omp_get_thread_num();
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
{
|
|
||||||
auto local = OpType::startingValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
// merge results
|
||||||
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
// return result
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return OpType::postProcess(startingVal, length, extraParams);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceLongFunction, , LIBND4J_TYPES, LONG_TYPES);
|
||||||
|
|
|
@ -57,7 +57,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < length; i++)
|
for (uint i = 0; i < length; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -67,25 +67,29 @@ namespace functions {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
X start = OpType::startingValue(x);
|
auto startingValue = OpType::startingValue(x);
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
|
||||||
X intermediate[256];
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
X intermediate[64];
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
|
PRAGMA_OMP_SIMD
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
|
};
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
z[0] = OpType::postProcess(start, length, extraParams);
|
// merge results
|
||||||
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
|
// write out results
|
||||||
|
z[0] = OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,26 +107,15 @@ namespace functions {
|
||||||
|
|
||||||
if (xEws >= 1) {
|
if (xEws >= 1) {
|
||||||
return execScalar<OpType>(x, xEws, length, extraParams);
|
return execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
} else {
|
||||||
else {
|
auto startingValue = OpType::startingValue(x);
|
||||||
X start = OpType::startingValue(x);
|
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
|
||||||
X intermediate[256];
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
|
||||||
intermediate[e] = start;
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
|
for (auto i = 0; i < length; i++)
|
||||||
for(Nd4jLong i = 0; i < length; ++i)
|
startingValue = OpType::update(startingValue, OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
||||||
intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
|
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
return OpType::postProcess(startingValue, length, extraParams);
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
|
||||||
|
|
||||||
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +147,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
|
||||||
xShapeInfo,
|
xShapeInfo,
|
||||||
extraParams,
|
extraParams,
|
||||||
|
@ -163,7 +156,7 @@ namespace functions {
|
||||||
dimension,
|
dimension,
|
||||||
dimensionLength,
|
dimensionLength,
|
||||||
tadShapeInfo,
|
tadShapeInfo,
|
||||||
tadOffset),
|
tadOffset, start, stop),
|
||||||
REDUCE_SAME_OPS);
|
REDUCE_SAME_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,7 +170,7 @@ namespace functions {
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
Nd4jLong *tadOffset) {
|
Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
|
@ -189,7 +182,7 @@ namespace functions {
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < zLength; i++)
|
for (uint i = 0; i < zLength; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
return;
|
return;
|
||||||
|
@ -223,9 +216,9 @@ namespace functions {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::ReductionSameLoops<X>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams);
|
nd4j::ReductionSameLoops<X>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,48 +239,34 @@ namespace functions {
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
X _CUDA_H ReduceSameFunction<X>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
X intermediate[64];
|
||||||
|
|
||||||
auto startingVal = OpType::startingValue(x);
|
PRAGMA_OMP_SIMD
|
||||||
nd4j::OmpLaunchHelper info(length);
|
for (auto e = 0; e < maxThreads; e++)
|
||||||
|
intermediate[e] = OpType::startingValue(x);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
if (xEws == 1) {
|
if (xEws == 1) {
|
||||||
|
for (auto i = start; i < stop; i++)
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], extraParams), extraParams);
|
||||||
{
|
} else {
|
||||||
auto local = OpType::startingValue(x);
|
for (auto i = start; i < stop; i++)
|
||||||
auto threadNum = omp_get_thread_num();
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i * xEws], extraParams), extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
|
||||||
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
};
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
{
|
|
||||||
auto local = OpType::startingValue(x);
|
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws*threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
// merge results
|
||||||
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
|
for (int e = 1; e < maxThreads; e++)
|
||||||
|
intermediate[0] = OpType::update(intermediate[0], intermediate[e], extraParams);
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
// return result
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return OpType::postProcess(startingVal, length, extraParams);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ReduceSameFunction, , LIBND4J_TYPES);
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <loops/legacy_ops.h>
|
#include <loops/legacy_ops.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
#include <Loops.h>
|
#include <Loops.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -51,72 +52,82 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
return;
|
return;
|
||||||
const auto startingVal = OpType::startingValue(x);
|
const auto startingVal = OpType::startingValue(x);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < length; i++)
|
for (uint i = 0; i < length; i++)
|
||||||
z[i] = startingVal;
|
z[i] = startingVal;
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
|
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
|
||||||
// it's possible case for EqualsWithEps op
|
|
||||||
if (extraParams != nullptr)
|
|
||||||
extraParamsVals[2] = extraParams[0];
|
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
Z startingVal = OpType::startingValue(x);
|
Z startingVal = OpType::startingValue(x);
|
||||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
int maxThreads = nd4j::math::nd4j_min<int>(64, nd4j::Environment::getInstance()->maxThreads());
|
||||||
nd4j::OmpLaunchHelper t(length, maxThreads);
|
Z intermediate[64];
|
||||||
Z intermediate[256];
|
Z extraParamsLocal[3 * 64];
|
||||||
Z extraParamsLocal[3 * 256];
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
intermediate[e] = startingVal;
|
intermediate[e] = startingVal;
|
||||||
|
|
||||||
memset(extraParamsLocal, 0, 3 * 256 * sizeof(Z));
|
memset(extraParamsLocal, 0, 3 * 64 * sizeof(Z));
|
||||||
if (extraParams != nullptr)
|
if (extraParams != nullptr) {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int e = 0; e < maxThreads; e++)
|
// mostly for future reference
|
||||||
extraParamsLocal[3 * e + 2] = extraParams[0];
|
for (int e = 0; e < maxThreads; e++) {
|
||||||
|
extraParamsLocal[3 * e] = extraParams[0];
|
||||||
|
extraParamsLocal[3 * e + 1] = extraParams[1];
|
||||||
|
extraParamsLocal[3 * e + 2] = extraParams[2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo);
|
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, yShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(unsigned int i = 0; i < length; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
const auto threadNum = omp_get_thread_num();
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[i], y[i], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[i], y[i], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
|
|
||||||
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
} else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(unsigned int i = 0; i < length; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
const auto threadNum = omp_get_thread_num();
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
} else {
|
} else {
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
uint yShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for(unsigned int i = 0; i < length; i++) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
const auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
|
intermediate[thread_id] = OpType::update(intermediate[thread_id], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * thread_id), extraParamsLocal + 3 * thread_id);
|
||||||
intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
maxThreads = samediff::Threads::parallel_for(func, 0, length, 1, maxThreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
// merge step
|
// merge step
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e);
|
OpType::aggregateExtraParams(extraParamsVals, extraParamsLocal + 3 * e);
|
||||||
|
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals);
|
startingVal = OpType::update(startingVal, intermediate[e], extraParamsVals);
|
||||||
|
|
||||||
|
// writing out result
|
||||||
z[0] = OpType::postProcess(startingVal, length, extraParamsVals);
|
z[0] = OpType::postProcess(startingVal, length, extraParamsVals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +150,7 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength) {
|
int *dimension, int dimensionLength, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X*>(vx);
|
auto x = reinterpret_cast<X*>(vx);
|
||||||
auto y = reinterpret_cast<X*>(vy);
|
auto y = reinterpret_cast<X*>(vy);
|
||||||
|
@ -151,9 +162,9 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,16 +176,16 @@ void Reduce3<X,Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template loopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,7 +199,7 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
||||||
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
|
@ -196,9 +207,9 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
|
||||||
auto extraParams = reinterpret_cast<Z*>(vextraParams);
|
auto extraParams = reinterpret_cast<Z*>(vextraParams);
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
#ifdef INLINE_LOOPS
|
||||||
nd4j::Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template loopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop);
|
||||||
#else
|
#else
|
||||||
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams);
|
nd4j::Reduction3Loops<X,Z>::template innerloopReduce3All<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,9 +220,9 @@ void Reduce3<X,Y>::exec( const int opNum,
|
||||||
void *extraParamsVals,
|
void *extraParamsVals,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength) {
|
int *dimension, int dimensionLength, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, start, stop), REDUCE3_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -223,9 +234,9 @@ void Reduce3<X,Y>::exec( const int opNum,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(vx,xShapeInfo,extraParamsVals,vy, yShapeInfo,vz,zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), REDUCE3_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -238,9 +249,9 @@ void Reduce3<X,Y>::execAll(const int opNum,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
|
||||||
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets, int64_t start, int64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), REDUCE3_OPS);
|
DISPATCH_BY_OPNUM_TT(execAll, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets, start, stop), REDUCE3_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
#include "../legacy_ops.h"
|
#include "../legacy_ops.h"
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
@ -39,7 +40,8 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vscalars,
|
void *vscalars,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
@ -63,29 +65,27 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
|
int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,9 +98,10 @@ void ScalarTransform<X,Y,Z>::transform(int opNum,
|
||||||
void *scalars,
|
void *scalars,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_OPS);
|
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -110,9 +111,10 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
|
||||||
void *z, Nd4jLong zStride,
|
void *z, Nd4jLong zStride,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
const Nd4jLong n, bool allowParallelism) {
|
const uint64_t n,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, allowParallelism), SCALAR_OPS);
|
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, start, stop), SCALAR_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -121,9 +123,10 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
|
||||||
void *x, Nd4jLong *xShapeInfo,
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
void *z, Nd4jLong *zShapeInfo,
|
void *z, Nd4jLong *zShapeInfo,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, allowParallelism), SCALAR_OPS);
|
DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -132,7 +135,8 @@ template<typename OpType>
|
||||||
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
|
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams, bool allowParallelism) {
|
void *vextraParams,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
@ -146,48 +150,30 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
|
||||||
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, allowParallelism);
|
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
|
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
|
for (auto i = start; i < stop; i++) {
|
||||||
{
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadNum = omp_get_thread_num();
|
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
};
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -199,44 +185,22 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
|
||||||
void *vz, Nd4jLong zEws,
|
void *vz, Nd4jLong zEws,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong len, bool allowParallelism) {
|
const uint64_t len, const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto scalar = reinterpret_cast<Y *>(vscalar)[0];
|
auto scalar = reinterpret_cast<Y *>(vscalar)[0];
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
|
|
||||||
|
|
||||||
if (xEws == 1 && zEws == 1) {
|
if (xEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws * threadOffset;
|
|
||||||
auto zi = z + zEws * threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
#include "../legacy_ops.h"
|
#include "../legacy_ops.h"
|
||||||
|
|
||||||
|
@ -39,7 +40,8 @@ namespace functions {
|
||||||
void *vscalars,
|
void *vscalars,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
@ -64,29 +66,27 @@ namespace functions {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
|
int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO
|
else {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,8 +103,8 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadOffsets, const uint64_t start, const uint64_t stop) {
|
||||||
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,8 +116,9 @@ namespace functions {
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
const Nd4jLong n) {
|
const uint64_t n,
|
||||||
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_BOOL_OPS);
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
|
@ -127,8 +128,9 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams) {
|
void *extraParams,
|
||||||
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_BOOL_OPS);
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
|
@ -138,7 +140,8 @@ namespace functions {
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams) {
|
void *vextraParams,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
@ -149,53 +152,33 @@ namespace functions {
|
||||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
auto len = shape::length(xShapeInfo);
|
auto len = shape::length(xShapeInfo);
|
||||||
|
|
||||||
// nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);
|
|
||||||
|
|
||||||
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len);
|
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len);
|
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++) {
|
||||||
{
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadNum = omp_get_thread_num();
|
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
};
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,44 +191,23 @@ namespace functions {
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong len) {
|
const uint64_t len,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto scalar = reinterpret_cast<X *>(vscalar)[0];
|
auto scalar = reinterpret_cast<X *>(vscalar)[0];
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len);
|
|
||||||
|
|
||||||
if (xEws == 1 && zEws == 1) {
|
if (xEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws * threadOffset;
|
|
||||||
auto zi = z + zEws * threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <types/types.h>
|
#include <types/types.h>
|
||||||
#include <LoopKind.h>
|
#include <LoopKind.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
#include "../legacy_ops.h"
|
#include "../legacy_ops.h"
|
||||||
|
|
||||||
|
@ -39,7 +40,8 @@ namespace functions {
|
||||||
void *vscalars,
|
void *vscalars,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
|
@ -64,29 +66,27 @@ namespace functions {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
|
int num_threads = nd4j::math::nd4j_min<int>(numTads, nd4j::Environment::getInstance()->maxThreads());
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO
|
else {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
|
for (auto r = start; r < stop; r++) {
|
||||||
for (unsigned int r = 0; r < numTads; r++) {
|
|
||||||
auto oZ = z + zTadOffsets[r];
|
auto oZ = z + zTadOffsets[r];
|
||||||
auto oX = x + xTadOffsets[r];
|
auto oX = x + xTadOffsets[r];
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,8 +103,10 @@ namespace functions {
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
Nd4jLong *xTadOffsets,
|
Nd4jLong *xTadOffsets,
|
||||||
Nd4jLong *zTadShapeInfo,
|
Nd4jLong *zTadShapeInfo,
|
||||||
Nd4jLong *zTadOffsets) {
|
Nd4jLong *zTadOffsets,
|
||||||
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS);
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets, start, stop), SCALAR_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,8 +118,9 @@ namespace functions {
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
const Nd4jLong n) {
|
const uint64_t n,
|
||||||
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS);
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n, start, stop), SCALAR_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
|
@ -127,8 +130,9 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *scalar,
|
void *scalar,
|
||||||
void *extraParams) {
|
void *extraParams,
|
||||||
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS);
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, start, stop), SCALAR_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
|
@ -138,7 +142,8 @@ namespace functions {
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams) {
|
void *vextraParams,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
|
@ -149,53 +154,33 @@ namespace functions {
|
||||||
auto zEws = shape::elementWiseStride(zShapeInfo);
|
auto zEws = shape::elementWiseStride(zShapeInfo);
|
||||||
auto len = shape::length(xShapeInfo);
|
auto len = shape::length(xShapeInfo);
|
||||||
|
|
||||||
// nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);
|
|
||||||
|
|
||||||
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len);
|
transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, start, stop);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len);
|
|
||||||
|
|
||||||
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++) {
|
||||||
{
|
auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadNum = omp_get_thread_num();
|
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
};
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
z[offset] = OpType::op(x[offset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
uint zShapeInfoCast[MAX_RANK];
|
||||||
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_SIMD
|
||||||
{
|
for (auto i = start; i < stop; i++) {
|
||||||
auto threadNum = omp_get_thread_num();
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
auto zOffset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ);
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
||||||
|
};
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++) {
|
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,44 +193,23 @@ namespace functions {
|
||||||
Nd4jLong zEws,
|
Nd4jLong zEws,
|
||||||
void *vscalar,
|
void *vscalar,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
const Nd4jLong len) {
|
const uint64_t len,
|
||||||
|
const uint64_t start, const uint64_t stop) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto scalar = reinterpret_cast<X *>(vscalar)[0];
|
auto scalar = reinterpret_cast<X *>(vscalar)[0];
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
nd4j::OmpLaunchHelper info(len);
|
|
||||||
|
|
||||||
if (xEws == 1 && zEws == 1) {
|
if (xEws == 1 && zEws == 1) {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i] = OpType::op(x[i], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i] = OpType::op(xi[i], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
PRAGMA_OMP_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
for (auto i = start; i < stop; i++)
|
||||||
{
|
z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
|
||||||
auto threadNum = omp_get_thread_num();
|
|
||||||
auto threadOffset = info.getThreadOffset(threadNum);
|
|
||||||
auto xi = x + xEws * threadOffset;
|
|
||||||
auto zi = z + zEws * threadOffset;
|
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (unsigned int i = 0; i < ulen; i++)
|
|
||||||
zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <helpers/shape.h>
|
#include <helpers/shape.h>
|
||||||
#include <helpers/TAD.h>
|
#include <helpers/TAD.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -90,8 +91,7 @@ namespace functions {
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
uint xShapeInfoCast[MAX_RANK];
|
||||||
const bool canCast = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
const bool canCast = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < length; i++) {
|
for (uint64_t i = 0; i < length; i++) {
|
||||||
|
|
||||||
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast);
|
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast);
|
||||||
|
|
||||||
SummaryStatsData<X> curr;
|
SummaryStatsData<X> curr;
|
||||||
|
@ -123,7 +123,7 @@ namespace functions {
|
||||||
return;
|
return;
|
||||||
SummaryStatsData<X> comp;
|
SummaryStatsData<X> comp;
|
||||||
comp.initWithValue(x[0]);
|
comp.initWithValue(x[0]);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for (uint i = 0; i < resultLength; i++)
|
for (uint i = 0; i < resultLength; i++)
|
||||||
z[i] = OpType::getValue(biasCorrected, comp);
|
z[i] = OpType::getValue(biasCorrected, comp);
|
||||||
return;
|
return;
|
||||||
|
@ -157,35 +157,37 @@ namespace functions {
|
||||||
uint tadShapeShapeInfoCast[MAX_RANK];
|
uint tadShapeShapeInfoCast[MAX_RANK];
|
||||||
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
|
const bool canCast = tadEWS == 1 && tadOrder == 'c' ? false : nd4j::DataTypeUtils::castShapeInfo<uint>(tadShapeShapeInfo, tadShapeShapeInfoCast);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (int r = 0; r < resultLength; r++) {
|
for (auto r = start; r < stop; r += increment) {
|
||||||
|
|
||||||
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
|
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
|
||||||
auto tx = x + tadOffsetForBlock;
|
auto tx = x + tadOffsetForBlock;
|
||||||
SummaryStatsData<X> comp;
|
SummaryStatsData <X> comp;
|
||||||
comp.initWithValue(tx[0]);
|
comp.initWithValue(tx[0]);
|
||||||
|
|
||||||
if (tadEWS == 1 && tadOrder == 'c') {
|
if (tadEWS == 1 && tadOrder == 'c') {
|
||||||
for (int i = 1; i < tadLength; i ++) {
|
for (int i = 1; i < tadLength; i++) {
|
||||||
SummaryStatsData <X> indexVal2;
|
SummaryStatsData <X> indexVal2;
|
||||||
indexVal2.initWithValue(tx[i]);
|
indexVal2.initWithValue(tx[i]);
|
||||||
|
|
||||||
comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
|
comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 1; i < tadLength; i++) {
|
||||||
|
auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast);
|
||||||
|
|
||||||
|
SummaryStatsData <X> indexVal2;
|
||||||
|
indexVal2.initWithValue(tx[xOffset]);
|
||||||
|
|
||||||
|
comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
z[r] = OpType::getValue(biasCorrected, comp);
|
||||||
}
|
}
|
||||||
else {
|
};
|
||||||
for (int i = 1; i < tadLength; i ++) {
|
|
||||||
auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast);
|
|
||||||
|
|
||||||
SummaryStatsData <X> indexVal2;
|
samediff::Threads::parallel_tad(func, 0, resultLength, 1);
|
||||||
indexVal2.initWithValue(tx[xOffset]);
|
|
||||||
|
|
||||||
comp = update(comp, OpType::op(indexVal2, extraParams), extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
z[r] = OpType::getValue(biasCorrected, comp);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,8 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo,
|
uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadOffsets, bool allowParallelism) {
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_ANY_OPS);
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), TRANSFORM_ANY_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////
|
||||||
|
@ -47,22 +46,13 @@ template <typename X, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vz,Nd4jLong *zShapeInfo,
|
void *vz,Nd4jLong *zShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,Nd4jLong *tadOffsets, bool allowParallelism) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
if(OpType::requiresSpecial) {
|
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
|
||||||
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (allowParallelism)
|
|
||||||
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
|
||||||
else
|
|
||||||
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, false>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,8 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo,
|
uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadOffsets) {
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_BOOL_OPS);
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_BOOL_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -49,20 +48,13 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
if(OpType::requiresSpecial) {
|
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
|
||||||
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
|
@ -36,9 +36,8 @@ namespace functions {
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo,
|
uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadOffsets) {
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_FLOAT_OPS);
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_FLOAT_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -48,20 +47,13 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
|
|
||||||
if(OpType::requiresSpecial) {
|
nd4j::TransformLoops<X,Z,Z>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
|
||||||
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
nd4j::TransformLoops<X,Z,Z>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
|
@ -36,10 +36,8 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_SAME_OPS);
|
||||||
Nd4jLong *tadOffsets) {
|
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_SAME_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -47,18 +45,14 @@ namespace functions {
|
||||||
void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo,
|
void _CUDA_H TransformSame<X>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
uint64_t threadId, uint64_t numThreads) {
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
if(OpType::requiresSpecial) {
|
|
||||||
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformSame, , LIBND4J_TYPES);
|
||||||
|
|
|
@ -36,10 +36,8 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,
|
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads), TRANSFORM_STRICT_OPS);
|
||||||
Nd4jLong *tadOffsets) {
|
|
||||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_STRICT_OPS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -49,20 +47,13 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams, uint64_t threadId, uint64_t numThreads) {
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
if(OpType::requiresSpecial) {
|
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType>(x, xShapeInfo, z, zShapeInfo, extraParams, threadId, numThreads);
|
||||||
OpType::execSpecial(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
nd4j::TransformLoops<X,X,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TransformStrict, , FLOAT_TYPES);
|
||||||
|
|
|
@ -1,145 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author raver119@gmail.com
|
|
||||||
// @author Yurii Shyrma, created on 27.11.2018
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "../aggregates.h"
|
|
||||||
|
|
||||||
namespace functions {
|
|
||||||
namespace aggregate {
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
template<typename OpClass>
|
|
||||||
__device__ void AggregatedFunction<X>::execCuda(X **arguments, int numArguments,
|
|
||||||
Nd4jLong **shapeArguments, int numShapeArguments,
|
|
||||||
int *indexArguments, int numIndexArguments,
|
|
||||||
int **intArrays, int numIntArrays,
|
|
||||||
X *realArguments, int numRealArguments) {
|
|
||||||
|
|
||||||
OpClass::executeAggregateCuda(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__device__ void AggregatedFunction<X>::execCuda(int opNum,
|
|
||||||
X **arguments, int numArguments,
|
|
||||||
Nd4jLong **shapeArguments, int numShapeArguments,
|
|
||||||
int *indexArguments, int numIndexArguments,
|
|
||||||
int **intArrays, int numIntArrays,
|
|
||||||
X *realArguments, int numRealArguments) {
|
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_T(execCuda, PARAMS(arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), AGGREGATE_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__global__ static void execAggregateKernel(int opNum,
|
|
||||||
void **varguments, int numArguments,
|
|
||||||
Nd4jLong **shapeArguments, int numShapeArguments,
|
|
||||||
int *indexArguments, int numIndexArguments,
|
|
||||||
int **intArrays, int numIntArrays,
|
|
||||||
void *vrealArguments, int numRealArguments) {
|
|
||||||
|
|
||||||
auto arguments = reinterpret_cast<X**>(varguments);
|
|
||||||
auto realArguments = reinterpret_cast<X*>(vrealArguments);
|
|
||||||
functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__host__ void AggregatedFunction<X>::aggregateKernelGeneric(dim3& launchDims, cudaStream_t *stream,
|
|
||||||
int opNum,
|
|
||||||
void **arguments, int numArguments,
|
|
||||||
Nd4jLong **shapeArguments, int numShapeArguments,
|
|
||||||
int *indexArguments, int numIndexArguments,
|
|
||||||
int **intArrays, int numIntArrays,
|
|
||||||
void *realArguments, int numRealArguments) {
|
|
||||||
|
|
||||||
execAggregateKernel<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "aggregateKernelGeneric(...) failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__device__ void AggregatedFunction<X>::aggregateBatch(int opNum, int numAggregates,
|
|
||||||
int maxArgs, int maxShapes,
|
|
||||||
int maxIntArrays, int maxIntArraySize,
|
|
||||||
int maxIdx, int maxReals,
|
|
||||||
void *ptrToArguments) {
|
|
||||||
|
|
||||||
nd4j::PointersHelper<X> helper(ptrToArguments, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals);
|
|
||||||
|
|
||||||
// TODO: we probably should lift this restriction
|
|
||||||
__shared__ int *intArrays[32];
|
|
||||||
|
|
||||||
__shared__ X **arguments;
|
|
||||||
__shared__ Nd4jLong **shapes;
|
|
||||||
__shared__ int *idxArg;
|
|
||||||
__shared__ X *realArg;
|
|
||||||
|
|
||||||
for(int r = blockIdx.x; r < numAggregates; r += gridDim.x) {
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
arguments = helper.getArguments(r);
|
|
||||||
shapes = helper.getShapeArguments(r);
|
|
||||||
idxArg = helper.getIndexArguments(r);
|
|
||||||
realArg = helper.getRealArguments(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
// we fill intArrays param in parallel within block
|
|
||||||
if (threadIdx.x < 32 && threadIdx.x < maxIntArrays) {
|
|
||||||
intArrays[threadIdx.x] = helper.getIntArrayArguments(r, threadIdx.x);
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
functions::aggregate::AggregatedFunction<X>::execCuda(opNum, arguments, helper.getNumArguments(r), shapes, helper.getNumShapeArguments(r), idxArg, helper.getNumIndexArguments(r), intArrays, helper.getNumIntArrayArguments(r), realArg, helper.getNumRealArguments(r));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__global__ static void execAggregateBatch(int opNum, int numAggregates,
|
|
||||||
int maxArgs, int maxShapes,
|
|
||||||
int maxIntArrays, int maxIntArraySize,
|
|
||||||
int maxIdx, int maxReals,
|
|
||||||
void *ptrToArguments) {
|
|
||||||
|
|
||||||
functions::aggregate::AggregatedFunction<X>::aggregateBatch(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename X>
|
|
||||||
__host__ void AggregatedFunction<X>::aggregateBatchKernelGeneric(dim3& launchDims, cudaStream_t *stream,
|
|
||||||
int opNum, int numAggregates,
|
|
||||||
int maxArgs, int maxShapes,
|
|
||||||
int maxIntArrays, int maxIntArraySize,
|
|
||||||
int maxIdx, int maxReals,
|
|
||||||
void *ptrToArguments) {
|
|
||||||
|
|
||||||
execAggregateBatch<X><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "aggregateBatchKernel(...) failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class AggregatedFunction, , FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -32,84 +32,6 @@
|
||||||
|
|
||||||
namespace functions {
|
namespace functions {
|
||||||
namespace broadcast {
|
namespace broadcast {
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void Broadcast<X, Y, Z>::execInverse(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void Broadcast<X, Y, Z>::exec(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* CPU execution
|
|
||||||
* @param x the input
|
|
||||||
* @param xShapeInfo the x shape information
|
|
||||||
* @param y the y data
|
|
||||||
* @param yShapeInfo the y shape information
|
|
||||||
* @param result the result
|
|
||||||
* @param resultShapeInfo the result shape information
|
|
||||||
* @param dimension the dimension to broadcast along long
|
|
||||||
* @param dimensionLength the length of the dimension buffer
|
|
||||||
*/
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
void Broadcast<X, Y, Z>::exec(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
void Broadcast<X, Y, Z>::execInverse(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -224,76 +224,6 @@ namespace functions {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void BroadcastBool<X,Y>::exec(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void BroadcastBool<X,Y>::execInverse(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void BroadcastBool<X,Y>::exec(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void BroadcastBool<X,Y>::execInverse(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -217,75 +217,6 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void BroadcastInt<X>::exec(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void BroadcastInt<X>::execInverse(int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void BroadcastInt<X>::exec(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void BroadcastInt<X>::execInverse(void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeInfo,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo,
|
|
||||||
Nd4jLong *tadOffset,
|
|
||||||
Nd4jLong *tadShapeInfoZ,
|
|
||||||
Nd4jLong *tadOffsetZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -359,32 +359,6 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
Nd4jLong IndexReduce<X,Z>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
void IndexReduce<X,Z>::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
Nd4jLong IndexReduce<X,Z>:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
_CUDA_H void IndexReduce<X,Z>::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,58 +22,6 @@
|
||||||
|
|
||||||
namespace functions {
|
namespace functions {
|
||||||
namespace pairwise_transforms {
|
namespace pairwise_transforms {
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void PairWiseTransform<X, Y, Z>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeInfo,
|
|
||||||
void *z,
|
|
||||||
Nd4jLong *zShapeInfo,
|
|
||||||
void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
void PairWiseTransform<X, Y, Z>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *z,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *extraParams,
|
|
||||||
Nd4jLong len) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseTransform<X, Y, Z>:: exec(
|
|
||||||
void *vx,
|
|
||||||
Nd4jLong* xShapeInfo,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong* yShapeInfo,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong* zShapeInfo,
|
|
||||||
void *vextraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseTransform<X, Y, Z>::exec(void *vx,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *vextraParams,
|
|
||||||
const Nd4jLong len) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -110,63 +110,6 @@ void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_
|
||||||
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void PairWiseBoolTransform<X,Y>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *dx,
|
|
||||||
Nd4jLong *xShapeBuffer,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeBuffer,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeBuffer,
|
|
||||||
void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void PairWiseBoolTransform<X,Y>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *dx,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *extraParams,
|
|
||||||
Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseBoolTransform<X,Y>::exec(
|
|
||||||
void *vx,
|
|
||||||
Nd4jLong* xShapeBuffer,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong* yShapeBuffer,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong* resultShapeBuffer,
|
|
||||||
void *vextraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseBoolTransform<X,Y>::exec(void *vx,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *vextraParams,
|
|
||||||
const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -109,63 +109,6 @@ void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *
|
||||||
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);
|
DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void PairWiseIntTransform<X>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *dx,
|
|
||||||
Nd4jLong *xShapeBuffer,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong *yShapeBuffer,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong *resultShapeBuffer,
|
|
||||||
void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void PairWiseIntTransform<X>::exec(
|
|
||||||
const int opNum,
|
|
||||||
void *dx,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *y,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *result,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *extraParams,
|
|
||||||
Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseIntTransform<X>::exec(
|
|
||||||
void *vx,
|
|
||||||
Nd4jLong* xShapeBuffer,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong* yShapeBuffer,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong* resultShapeBuffer,
|
|
||||||
void *vextraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void PairWiseIntTransform<X>::exec(void *vx,
|
|
||||||
Nd4jLong xStride,
|
|
||||||
void *vy,
|
|
||||||
Nd4jLong yStride,
|
|
||||||
void *vresult,
|
|
||||||
Nd4jLong resultStride,
|
|
||||||
void *vextraParams,
|
|
||||||
const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -442,39 +442,6 @@ namespace functions {
|
||||||
DEBUG_KERNEL(stream, opNum);
|
DEBUG_KERNEL(stream, opNum);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
template<typename OpClass>
|
|
||||||
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
template<typename OpClass>
|
|
||||||
void RandomFunction<T>::execTransform(Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
template<typename OpClass>
|
|
||||||
void RandomFunction<T>::execTransform(Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *x, Nd4jLong *xShapeBuffer, void *y, Nd4jLong *yShapeBuffer, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
void RandomFunction<T>::execTransform(int opNum, Nd4jPointer state, void *z, Nd4jLong *zShapeBuffer, void *extraArguments) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,7 +132,7 @@ __device__ void Reduce3<X,Z>::execScalarCuda( void *vx, Nd4jLong *xShapeInfo,
|
||||||
extraZ[1] = (Z) 0.0f;
|
extraZ[1] = (Z) 0.0f;
|
||||||
|
|
||||||
if (extraParams != nullptr)
|
if (extraParams != nullptr)
|
||||||
extraZ[2] = *(static_cast<Z*>(extraParams));
|
extraZ[2] = static_cast<Z*>(extraParams)[2];
|
||||||
else
|
else
|
||||||
extraZ[2] = (Z) 0.0f;
|
extraZ[2] = (Z) 0.0f;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,56 +27,7 @@
|
||||||
|
|
||||||
namespace functions {
|
namespace functions {
|
||||||
namespace reduce3 {
|
namespace reduce3 {
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void Reduce3<X,Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void Reduce3<X,Y>::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParamsVals, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void Reduce3<X,Y>::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void Reduce3<X,Y>::execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void Reduce3<X,Y>::exec(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void Reduce3<X,Y>::execAll(const int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -231,41 +231,6 @@ void ScalarBoolTransform<X,Y>::executeCudaAlongDimension(dim3& launchDims, cudaS
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ScalarBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template <typename OpType>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void ScalarBoolTransform<X,Y>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -230,40 +230,6 @@ void ScalarIntTransform<X>::executeCudaAlongDimension(dim3& launchDims, cudaStre
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);
|
||||||
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template <typename OpType>
|
|
||||||
void ScalarIntTransform<X,>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void ScalarIntTransform<X>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename X>
|
|
||||||
template<typename OpType>
|
|
||||||
void ScalarIntTransform<X>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -414,73 +414,6 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
Y SummaryStatsReduce<X,Y>::execScalar(int opNum,
|
|
||||||
bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void SummaryStatsReduce<X,Y>::execScalar(int opNum,
|
|
||||||
bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *vz,
|
|
||||||
Nd4jLong *resultShapeInfoBuffer) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void SummaryStatsReduce<X,Y>::exec(int opNum,
|
|
||||||
bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *vz,
|
|
||||||
Nd4jLong *resultShapeInfoBuffer,
|
|
||||||
int *dimension, int dimensionLength) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
Y SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void SummaryStatsReduce<X,Y>::execScalar(bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *vz,
|
|
||||||
Nd4jLong *resultShapeInfoBuffer) {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
template<typename OpType>
|
|
||||||
void SummaryStatsReduce<X,Y>::exec(bool biasCorrected,
|
|
||||||
void *x,
|
|
||||||
Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *vz,
|
|
||||||
Nd4jLong *resultShapeInfoBuffer,
|
|
||||||
int *dimension,
|
|
||||||
int dimensionLength) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT SummaryStatsReduce, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -114,17 +114,6 @@ namespace functions {
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "transformAny(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
void TransformAny<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void TransformAny<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformAny, , LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,17 +120,6 @@ namespace functions {
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "transformBool(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
void TransformBool<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void TransformBool<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformBool, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -142,18 +142,6 @@ namespace functions {
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "transformFloat(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
void TransformFloat<X,Z>::exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename X, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void TransformFloat<X,Z>::exec(void *dx, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TransformFloat, , LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue