diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 51a29e522..2a12d5b9c 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -34,6 +34,11 @@ if (APPLE_BUILD) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10") endif() +if (ARM_BUILD) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DARM_BUILD=true") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DARM_BUILD=true") +endif() + if (ANDROID_BUILD) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DANDROID_BUILD=true") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DANDROID_BUILD=true") diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index c07756d8c..ae3fac13a 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -173,7 +173,7 @@ fi case "$OS" in linux-armhf) export RPI_BIN=$RPI_HOME/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf - export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake" + export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake -DARM_BUILD=true" if [ -z "$ARCH" ]; then ARCH="armv7-r" fi @@ -183,6 +183,7 @@ case "$OS" in if [ -z "$ARCH" ]; then ARCH="armv8-a" fi + export CMAKE_COMMAND="$CMAKE_COMMAND -DARM_BUILD=true" ;; android-arm) diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp index 5f2308418..e9f96ff70 100644 --- a/libnd4j/include/loops/cpu/scalar_int.cpp +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -201,15 +201,14 @@ namespace functions { auto scalar = reinterpret_cast(vscalar)[0]; auto extraParams = reinterpret_cast(vextraParams); - if (xEws == 1 && zEws == 1) { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i] = OpType::op(x[i], scalar, extraParams); - } - else { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i++) - z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + if (scalar < (sizeof(X) * 8)) { + if (xEws == 1 && zEws == 1) { + for (auto i = start; i < stop; i++) + z[i] = OpType::op(x[i], scalar, extraParams); + } else { + for (auto i = start; i < stop; i++) + z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams); + } } } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index ea32b154c..95f83be1a 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -186,7 +186,7 @@ (1, SummaryStatsStandardDeviation) #define SCALAR_INT_OPS \ - (0, ShiftLeft), \ + (0, ShiftLeft) ,\ (1, ShiftRight), \ (2, CyclicShiftLeft), \ (3, CyclicShiftRight), \ diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 1b54889d4..ff3ee570b 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -722,7 +722,7 @@ namespace simdOps { public: op_def static X op(X d1, X d2) { - return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2); + return nd4j::math::nd4j_rotl(d1, d2); } op_def static X op(X d1, X d2, X *params) { @@ -735,7 +735,7 @@ namespace simdOps { public: op_def static X op(X d1, X d2) { - return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2); + return nd4j::math::nd4j_rotr(d1, d2); } op_def static X op(X d1, X d2, X *params) { diff --git a/libnd4j/include/platformmath.h b/libnd4j/include/platformmath.h index b58e8f7f6..5c0a1d07f 100644 --- a/libnd4j/include/platformmath.h +++ b/libnd4j/include/platformmath.h @@ -80,6 +80,7 @@ union PAIR { #else #define math_def #include + #endif @@ -145,6 +146,12 @@ namespace nd4j { template math_def FORCEINLINE T p_rint(T value); + template + math_def FORCEINLINE T p_rotl(T value, T shift); + + template + math_def FORCEINLINE T p_rotr(T value, T shift); + template math_def FORCEINLINE T p_remainder(T val1, T val2); @@ -751,6 +758,116 @@ namespace nd4j { math_def FORCEINLINE double p_atanh(double value) { return atanh(value); } + +///////// + template + math_def FORCEINLINE T _rotate_left(T value, T shift); + + template + math_def FORCEINLINE T _rotate_right(T value, T shift); + + template <> + math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) { + return value << shift | value >> (8 - shift); + } + + template <> + math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) { + return value >> shift | value << (8 - shift); + } + + template <> + math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) { + return value << shift | value >> (8 - shift); + } + + template <> + math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) { + return value >> shift | value << (8 - shift); + } + + template <> + math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) { + return value << shift | value >> (16 - shift); + } + + template <> + math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) { + return value >> shift | value << (16 - shift); + } + + template <> + math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) { + return value << shift | value >> (16 - shift); + } + + template <> + math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) { + return value >> shift | value << (16 - shift); + } + + template <> + math_def FORCEINLINE int _rotate_left(int value, int shift) { + return value << shift | value >> (32 - shift); + } + + template <> + math_def FORCEINLINE int _rotate_right(int value, int shift) { + return value >> shift | value << (32 - shift); + } + + template <> + math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) { + return value << shift | value >> (32 - shift); + } + + template <> + math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) { + return value >> shift | value << (32 - shift); + } + + template <> + math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) { + return value << shift | value >> (64 - shift); + } + + template <> + math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) { + return value >> shift | value << (64 - shift); + } + + template <> + math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) { +#ifdef ARM_BUILD + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_left(*reinterpret_cast(&value), *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); +#else + return value << shift | value >> (64 - shift); +#endif + } + + template <> + math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) { +#ifdef ARM_BUILD + // TODO: eventually remove this once gcc fixes the bug + Nd4jLong val = _rotate_right(*reinterpret_cast(&value), *reinterpret_cast(&shift)); + return *reinterpret_cast(&val); +#else + return value >> shift | value << (64 - shift); +#endif + } + + + template + math_def FORCEINLINE T p_rotl(T value, T shift) { + return _rotate_left(value, shift); + } + + template + math_def FORCEINLINE T p_rotr(T value, T shift) { + return _rotate_right(value, shift); + } } } diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 48021d734..6163488e3 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -77,6 +77,12 @@ namespace nd4j { template math_def inline Z nd4j_softplus(T val); + template + math_def inline T nd4j_rotl(T val, T shift); + + template + math_def inline T nd4j_rotr(T val, T shift); + //#ifndef __CUDACC__ template math_def inline Z nd4j_dot(X *x, Y *y, int length); @@ -817,6 +823,16 @@ namespace nd4j { return val <= 0 ? neg_tanh(val) : pos_tanh(val); } + template + math_def inline T nd4j_rotl(T val, T shift) { + return p_rotl(val, shift); + } + + template + math_def inline T nd4j_rotr(T val, T shift) { + return p_rotr(val, shift); + } + template math_def inline Z nd4j_erf(X val) { return p_erf(static_cast(val));