alloc check for RNG (#179)
* missing alloc validation in RandomGenerator for CUDA Signed-off-by: raver119 <raver119@gmail.com> * set error message if rng alloc failed Signed-off-by: raver119 <raver119@gmail.com> * check for error code during RNG creation in java Signed-off-by: raver119 <raver119@gmail.com>master
parent
25db3a44f1
commit
256c9d20b0
|
@ -3602,7 +3602,13 @@ void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||||
|
|
||||||
|
|
||||||
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||||
|
try {
|
||||||
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <array/DataTypeUtils.h>
|
#include <array/DataTypeUtils.h>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
|
@ -46,7 +47,10 @@ namespace nd4j {
|
||||||
public:
|
public:
|
||||||
void *operator new(size_t len) {
|
void *operator new(size_t len) {
|
||||||
void *ptr;
|
void *ptr;
|
||||||
cudaHostAlloc(&ptr, len, cudaHostAllocDefault);
|
auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault);
|
||||||
|
if (res != 0)
|
||||||
|
throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory");
|
||||||
|
|
||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,10 @@ public class CudaNativeRandom extends NativeRandom {
|
||||||
public void init() {
|
public void init() {
|
||||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||||
|
|
||||||
|
if (nativeOps.lastErrorCode() != 0)
|
||||||
|
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||||
|
|
||||||
setSeed(seed);
|
setSeed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue