alloc check for RNG (#179)

* missing alloc validation in RandomGenerator for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* set error message if rng alloc failed

Signed-off-by: raver119 <raver119@gmail.com>

* check for error code during RNG creation in java

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-23 09:51:02 +03:00 committed by GitHub
parent 25db3a44f1
commit 256c9d20b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 2 deletions

View File

@ -3602,7 +3602,13 @@ void deleteGraphContext(nd4j::graph::Context* ptr) {
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); try {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
} catch (std::exception &e) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
return nullptr;
}
} }
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {

View File

@ -28,6 +28,7 @@
#include <chrono> #include <chrono>
#include <array/DataTypeUtils.h> #include <array/DataTypeUtils.h>
#include <helpers/logger.h> #include <helpers/logger.h>
#include <stdexcept>
#ifdef __CUDACC__ #ifdef __CUDACC__
#include <cuda.h> #include <cuda.h>
@ -46,7 +47,10 @@ namespace nd4j {
public: public:
void *operator new(size_t len) { void *operator new(size_t len) {
void *ptr; void *ptr;
cudaHostAlloc(&ptr, len, cudaHostAllocDefault); auto res = cudaHostAlloc(&ptr, len, cudaHostAllocDefault);
if (res != 0)
throw std::runtime_error("CudaManagedRandomGenerator: failed to allocate memory");
return ptr; return ptr;
} }

View File

@ -54,6 +54,10 @@ public class CudaNativeRandom extends NativeRandom {
public void init() { public void init() {
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
setSeed(seed); setSeed(seed);
} }