diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java index 243422b13..550c6eb70 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java @@ -23,6 +23,7 @@ import org.bytedeco.javacv.*; import org.datavec.image.loader.NativeImageLoader; import org.deeplearning4j.rl4j.util.VideoRecorder; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -90,6 +91,9 @@ public class HistoryProcessor implements IHistoryProcessor { public void record(INDArray raw) { if(isMonitoring()) { + // before accessing the raw pointer, we need to make sure that array is actual on the host side + Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST); + VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer()); try { videoRecorder.record(frame); @@ -110,6 +114,10 @@ public class HistoryProcessor implements IHistoryProcessor { private INDArray transform(INDArray raw) { long[] shape = raw.shape(); + + // before accessing the raw pointer, we need to make sure that array is actual on the host side + Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST); + Mat ocvmat = new Mat((int)shape[0], (int)shape[1], CV_32FC(3), raw.data().pointer()); Mat cvmat = new Mat(shape[0], shape[1], CV_8UC(3)); ocvmat.convertTo(cvmat, CV_8UC(3), 255.0, 0.0);