diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index a0529d106..726549415 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -344,7 +344,7 @@ bool NDArray::isS() const { ////////////////////////////////////////////////////////////////////////// bool NDArray::isR() const { auto xType = ArrayOptions::dataType(this->_shapeInfo); - return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8; + return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index aeeedc007..889e48181 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -276,23 +276,6 @@ namespace functions { DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS) } - // FIXME: eventually we might want to get rid of that -#ifndef __CLION_IDE__ -/* - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) - - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) - - BUILD_CALL_1(template void RandomFunction::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS) - BUILD_CALL_1(template void RandomFunction::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS) -*/ -#endif - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES); } diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 3a4552790..501a29b8c 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -27,6 +27,7 @@ #include #include #include +#include using namespace nd4j; @@ -104,6 +105,14 @@ TEST_F(DataTypesValidationTests, Basic_Test_4) { ASSERT_EQ(ND4J_STATUS_VALIDATION, result); } +TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); + + ASSERT_TRUE(x.sumNumber().e(0) > 0); +} + TEST_F(DataTypesValidationTests, cast_1) { float16 x = static_cast(1.f); diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 9f6f3d787..b646bacab 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1212,6 +1212,18 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) { ASSERT_EQ(e, z); } +TEST_F(JavaInteropTests, test_bfloat16_rng) { + if (!Environment::getInstance()->isCPU()) + return; + + auto z = NDArrayFactory::create('c', {10}); + RandomGenerator rng(119, 323841120L); + bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; + execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args); + z.printIndexedBuffer("z"); + ASSERT_TRUE(z.sumNumber().e(0) > 0); +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 51780fb2c..638cd8ac3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6521,6 +6521,15 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, array2); } + @Test + public void testRndBloat16() { + INDArray x = Nd4j.rand(DataType.BFLOAT16 , 'c', new long[]{5}); + assertTrue(x.sumNumber().floatValue() > 0); + + x = Nd4j.randn(DataType.BFLOAT16 , 10); + assertTrue(x.sumNumber().floatValue() > 0); + } + @Test public void testLegacyDeserialization_2() throws Exception { val f = new ClassPathResource("legacy/NDArray_longshape_float.bin").getFile();