parent
e0f8d86eac
commit
c7277729e9
|
@ -344,7 +344,7 @@ bool NDArray::isS() const {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::isR() const {
|
||||
auto xType = ArrayOptions::dataType(this->_shapeInfo);
|
||||
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8;
|
||||
return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -276,23 +276,6 @@ namespace functions {
|
|||
DISPATCH_BY_OPNUM_T(execTransform, PARAMS(state, z, zShapeInfo, extraArguments), RANDOM_OPS)
|
||||
}
|
||||
|
||||
// FIXME: eventually we might want to get rid of that
|
||||
#ifndef __CLION_IDE__
|
||||
/*
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *y, Nd4jLong *yShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *y, Nd4jLong *yShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *y, Nd4jLong *yShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *x, Nd4jLong *xShapeInfo, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *x, Nd4jLong *xShapeInfo, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *x, Nd4jLong *xShapeInfo, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
|
||||
BUILD_CALL_1(template void RandomFunction<float>::execTransform, float, (Nd4jPointer state, float *z, Nd4jLong *zShapeInfo, float *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<float16>::execTransform, float16, (Nd4jPointer state, float16 *z, Nd4jLong *zShapeInfo, float16 *extraArguments), RANDOM_OPS)
|
||||
BUILD_CALL_1(template void RandomFunction<double>::execTransform, double, (Nd4jPointer state, double *z, Nd4jLong *zShapeInfo, double *extraArguments), RANDOM_OPS)
|
||||
*/
|
||||
#endif
|
||||
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT RandomFunction, , FLOAT_TYPES);
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
#include <ops/declarable/helpers/col2im.h>
|
||||
#include <helpers/RandomLauncher.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
|
@ -104,6 +105,14 @@ TEST_F(DataTypesValidationTests, Basic_Test_4) {
|
|||
ASSERT_EQ(ND4J_STATUS_VALIDATION, result);
|
||||
}
|
||||
|
||||
TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) {
|
||||
auto x = NDArrayFactory::create<bfloat16>('c', {5, 10});
|
||||
RandomGenerator gen(119, 120);
|
||||
RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6);
|
||||
|
||||
ASSERT_TRUE(x.sumNumber().e<float>(0) > 0);
|
||||
}
|
||||
|
||||
TEST_F(DataTypesValidationTests, cast_1) {
|
||||
|
||||
float16 x = static_cast<float16>(1.f);
|
||||
|
|
|
@ -1212,6 +1212,18 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
|||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, test_bfloat16_rng) {
|
||||
if (!Environment::getInstance()->isCPU())
|
||||
return;
|
||||
|
||||
auto z = NDArrayFactory::create<bfloat16>('c', {10});
|
||||
RandomGenerator rng(119, 323841120L);
|
||||
bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f};
|
||||
execRandom(nullptr, nd4j::random::Ops::UniformDistribution, &rng, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), args);
|
||||
z.printIndexedBuffer("z");
|
||||
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||
|
|
|
@ -6521,6 +6521,15 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertEquals(exp, array2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRndBloat16() {
|
||||
INDArray x = Nd4j.rand(DataType.BFLOAT16 , 'c', new long[]{5});
|
||||
assertTrue(x.sumNumber().floatValue() > 0);
|
||||
|
||||
x = Nd4j.randn(DataType.BFLOAT16 , 10);
|
||||
assertTrue(x.sumNumber().floatValue() > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLegacyDeserialization_2() throws Exception {
|
||||
val f = new ClassPathResource("legacy/NDArray_longshape_float.bin").getFile();
|
||||
|
|
Loading…
Reference in New Issue