few fixes for bfloat16 in java and cpp (#114)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-14 21:51:42 +03:00 committed by GitHub
parent e0f8d86eac
commit c7277729e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 18 deletions

View File

@ -344,7 +344,7 @@ bool NDArray::isS() const {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isR() const {
auto xType = ArrayOptions::dataType(this->_shapeInfo);
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8;
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16;
}
//////////////////////////////////////////////////////////////////////////

View File

@ -276,23 +276,6 @@ namespace functions {
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
}
// FIXME: eventually we might want to get rid of that
#ifndef __CLION_IDE__
/*
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
*/
#endif
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
}

View File

@ -27,6 +27,7 @@
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <helpers/RandomLauncher.h>
using namespace nd4j;
@ -104,6 +105,14 @@ TEST_F(DataTypesValidationTests, Basic_Test_4) {
ASSERT_EQ(ND4J_STATUS_VALIDATION, result);
}
TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
RandomGenerator gen(119, 120);
RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6);
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
}
TEST_F(DataTypesValidationTests, cast_1) {
float16 x = static_cast<float16>(1.f);

View File

@ -1212,6 +1212,18 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
ASSERT_EQ(e, z);
}
TEST_F(JavaInteropTests, test_bfloat16_rng) {
if (!Environment::getInstance()->isCPU())
return;
auto z = NDArrayFactory::create<bfloat16>('c', {10});
RandomGenerator rng(119, 323841120L);
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args);
z.printIndexedBuffer("z");
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
}
/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -6521,6 +6521,15 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, array2);
}
@Test
public void testRndBloat16() {
INDArray x = Nd4j.rand(DataType.BFLOAT16 , 'c', new long[]{5});
assertTrue(x.sumNumber().floatValue() > 0);
x = Nd4j.randn(DataType.BFLOAT16 , 10);
assertTrue(x.sumNumber().floatValue() > 0);
}
@Test
public void testLegacyDeserialization_2() throws Exception {
val f = new ClassPathResource("legacy/NDArray_longshape_float.bin").getFile();