* range check for scalar_int

Signed-off-by: raver119 <raver119@gmail.com>

* no simd

Signed-off-by: raver119 <raver119@gmail.com>

* no ops

Signed-off-by: raver119 <raver119@gmail.com>

* cyclic shift?

Signed-off-by: raver119 <raver119@gmail.com>

* left split

Signed-off-by: raver119 <raver119@gmail.com>

* left split

Signed-off-by: raver119 <raver119@gmail.com>

* rot ops unrolled templates

Signed-off-by: raver119 <raver119@gmail.com>

* no rotl/rotr for uint64

Signed-off-by: raver119 <raver119@gmail.com>

* no rotl/rotr for uint64 2

Signed-off-by: raver119 <raver119@gmail.com>

* no rotl/rotr for uint64 3

Signed-off-by: raver119 <raver119@gmail.com>

* ARM_BUILD declared

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-21 14:31:00 +03:00 committed by GitHub
parent d19dbb955c
commit e78be14cc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 151 additions and 13 deletions

View File

@ -34,6 +34,11 @@ if (APPLE_BUILD)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DAPPLE_BUILD=true -mmacosx-version-min=10.10")
endif()
if (ARM_BUILD)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DARM_BUILD=true")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DARM_BUILD=true")
endif()
if (ANDROID_BUILD)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DANDROID_BUILD=true")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DANDROID_BUILD=true")

View File

@ -173,7 +173,7 @@ fi
case "$OS" in
linux-armhf)
export RPI_BIN=$RPI_HOME/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf
export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake"
export CMAKE_COMMAND="$CMAKE_COMMAND -D CMAKE_TOOLCHAIN_FILE=cmake/rpi.cmake -DARM_BUILD=true"
if [ -z "$ARCH" ]; then
ARCH="armv7-r"
fi
@ -183,6 +183,7 @@ case "$OS" in
if [ -z "$ARCH" ]; then
ARCH="armv8-a"
fi
export CMAKE_COMMAND="$CMAKE_COMMAND -DARM_BUILD=true"
;;
android-arm)

View File

@ -201,17 +201,16 @@ namespace functions {
auto scalar = reinterpret_cast<X *>(vscalar)[0];
auto extraParams = reinterpret_cast<X *>(vextraParams);
if (scalar < (sizeof(X) * 8)) {
if (xEws == 1 && zEws == 1) {
PRAGMA_OMP_SIMD
for (auto i = start; i < stop; i++)
z[i] = OpType::op(x[i], scalar, extraParams);
}
else {
PRAGMA_OMP_SIMD
} else {
for (auto i = start; i < stop; i++)
z[i * zEws] = OpType::op(x[i * xEws], scalar, extraParams);
}
}
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);

View File

@ -186,7 +186,7 @@
(1, SummaryStatsStandardDeviation)
#define SCALAR_INT_OPS \
(0, ShiftLeft), \
(0, ShiftLeft) ,\
(1, ShiftRight), \
(2, CyclicShiftLeft), \
(3, CyclicShiftRight), \

View File

@ -722,7 +722,7 @@ namespace simdOps {
public:
op_def static X op(X d1, X d2) {
return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2);
return nd4j::math::nd4j_rotl<X>(d1, d2);
}
op_def static X op(X d1, X d2, X *params) {
@ -735,7 +735,7 @@ namespace simdOps {
public:
op_def static X op(X d1, X d2) {
return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2);
return nd4j::math::nd4j_rotr<X>(d1, d2);
}
op_def static X op(X d1, X d2, X *params) {

View File

@ -80,6 +80,7 @@ union PAIR {
#else
#define math_def
#include <types/float16.h>
#endif
@ -145,6 +146,12 @@ namespace nd4j {
template <typename T>
math_def FORCEINLINE T p_rint(T value);
template <typename T>
math_def FORCEINLINE T p_rotl(T value, T shift);
template <typename T>
math_def FORCEINLINE T p_rotr(T value, T shift);
template <typename T>
math_def FORCEINLINE T p_remainder(T val1, T val2);
@ -751,6 +758,116 @@ namespace nd4j {
math_def FORCEINLINE double p_atanh(double value) {
return atanh(value);
}
/////////
template <typename T>
math_def FORCEINLINE T _rotate_left(T value, T shift);
template <typename T>
math_def FORCEINLINE T _rotate_right(T value, T shift);
template <>
math_def FORCEINLINE int8_t _rotate_left(int8_t value, int8_t shift) {
return value << shift | value >> (8 - shift);
}
template <>
math_def FORCEINLINE int8_t _rotate_right(int8_t value, int8_t shift) {
return value >> shift | value << (8 - shift);
}
template <>
math_def FORCEINLINE uint8_t _rotate_left(uint8_t value, uint8_t shift) {
return value << shift | value >> (8 - shift);
}
template <>
math_def FORCEINLINE uint8_t _rotate_right(uint8_t value, uint8_t shift) {
return value >> shift | value << (8 - shift);
}
template <>
math_def FORCEINLINE int16_t _rotate_left(int16_t value, int16_t shift) {
return value << shift | value >> (16 - shift);
}
template <>
math_def FORCEINLINE int16_t _rotate_right(int16_t value, int16_t shift) {
return value >> shift | value << (16 - shift);
}
template <>
math_def FORCEINLINE uint16_t _rotate_left(uint16_t value, uint16_t shift) {
return value << shift | value >> (16 - shift);
}
template <>
math_def FORCEINLINE uint16_t _rotate_right(uint16_t value, uint16_t shift) {
return value >> shift | value << (16 - shift);
}
template <>
math_def FORCEINLINE int _rotate_left(int value, int shift) {
return value << shift | value >> (32 - shift);
}
template <>
math_def FORCEINLINE int _rotate_right(int value, int shift) {
return value >> shift | value << (32 - shift);
}
template <>
math_def FORCEINLINE uint32_t _rotate_left(uint32_t value, uint32_t shift) {
return value << shift | value >> (32 - shift);
}
template <>
math_def FORCEINLINE uint32_t _rotate_right(uint32_t value, uint32_t shift) {
return value >> shift | value << (32 - shift);
}
template <>
math_def FORCEINLINE Nd4jLong _rotate_left(Nd4jLong value, Nd4jLong shift) {
return value << shift | value >> (64 - shift);
}
template <>
math_def FORCEINLINE Nd4jLong _rotate_right(Nd4jLong value, Nd4jLong shift) {
return value >> shift | value << (64 - shift);
}
template <>
math_def FORCEINLINE uint64_t _rotate_left(uint64_t value, uint64_t shift) {
#ifdef ARM_BUILD
// TODO: eventually remove this once gcc fixes the bug
Nd4jLong val = _rotate_left<Nd4jLong>(*reinterpret_cast<Nd4jLong *>(&value), *reinterpret_cast<Nd4jLong *>(&shift));
return *reinterpret_cast<uint64_t *>(&val);
#else
return value << shift | value >> (64 - shift);
#endif
}
template <>
math_def FORCEINLINE uint64_t _rotate_right(uint64_t value, uint64_t shift) {
#ifdef ARM_BUILD
// TODO: eventually remove this once gcc fixes the bug
Nd4jLong val = _rotate_right<Nd4jLong>(*reinterpret_cast<Nd4jLong *>(&value), *reinterpret_cast<Nd4jLong *>(&shift));
return *reinterpret_cast<uint64_t *>(&val);
#else
return value >> shift | value << (64 - shift);
#endif
}
template <typename T>
math_def FORCEINLINE T p_rotl(T value, T shift) {
return _rotate_left<T>(value, shift);
}
template <typename T>
math_def FORCEINLINE T p_rotr(T value, T shift) {
return _rotate_right<T>(value, shift);
}
}
}

View File

@ -77,6 +77,12 @@ namespace nd4j {
template <typename T, typename Z>
math_def inline Z nd4j_softplus(T val);
template <typename T>
math_def inline T nd4j_rotl(T val, T shift);
template <typename T>
math_def inline T nd4j_rotr(T val, T shift);
//#ifndef __CUDACC__
template<typename X, typename Y, typename Z>
math_def inline Z nd4j_dot(X *x, Y *y, int length);
@ -817,6 +823,16 @@ namespace nd4j {
return val <= 0 ? neg_tanh(val) : pos_tanh(val);
}
template <typename T>
math_def inline T nd4j_rotl(T val, T shift) {
return p_rotl<T>(val, shift);
}
template <typename T>
math_def inline T nd4j_rotr(T val, T shift) {
return p_rotr<T>(val, shift);
}
template <typename X, typename Z>
math_def inline Z nd4j_erf(X val) {
return p_erf<Z>(static_cast<Z>(val));