diff --git a/libnd4j/include/platformmath.h b/libnd4j/include/platformmath.h index b7cbe3745..b58e8f7f6 100644 --- a/libnd4j/include/platformmath.h +++ b/libnd4j/include/platformmath.h @@ -326,6 +326,11 @@ namespace nd4j { #endif } + template <> + math_def FORCEINLINE bfloat16 p_floor(bfloat16 value) { + return static_cast(floorf((float)value)); + } + template <> math_def FORCEINLINE double p_floor(double value) { return floor(value); @@ -352,6 +357,11 @@ namespace nd4j { #endif } + template <> + math_def FORCEINLINE bfloat16 p_ceil(bfloat16 value) { + return static_cast(ceilf((float)value)); + } + template <> math_def FORCEINLINE double p_ceil(double value) { return ceil(value); @@ -374,6 +384,12 @@ namespace nd4j { return static_cast(roundf((float) val)); } + template <> + math_def FORCEINLINE bfloat16 p_round(bfloat16 value) { + return static_cast(roundf((float)value)); + } + + template <> math_def FORCEINLINE double p_round(double value) { return round(value);