rl4j: update host pointers content before reading them

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-11 10:57:55 +03:00
parent 6aaca58506
commit a7a97d8259
1 changed files with 8 additions and 0 deletions

View File

@ -23,6 +23,7 @@ import org.bytedeco.javacv.*;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.rl4j.util.VideoRecorder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -90,6 +91,9 @@ public class HistoryProcessor implements IHistoryProcessor {
public void record(INDArray raw) {
if(isMonitoring()) {
// before accessing the raw pointer, we need to make sure that array is actual on the host side
Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST);
VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer());
try {
videoRecorder.record(frame);
@ -110,6 +114,10 @@ public class HistoryProcessor implements IHistoryProcessor {
private INDArray transform(INDArray raw) {
long[] shape = raw.shape();
// before accessing the raw pointer, we need to make sure that array is actual on the host side
Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST);
Mat ocvmat = new Mat((int)shape[0], (int)shape[1], CV_32FC(3), raw.data().pointer());
Mat cvmat = new Mat(shape[0], shape[1], CV_8UC(3));
ocvmat.convertTo(cvmat, CV_8UC(3), 255.0, 0.0);