datavec python ensure host (#113)

* ensure host

* one more host ensure

* info->debug
master
Fariz Rahman 2019-12-05 17:57:32 +05:30 committed by Alex Black
parent ef4d3ffee8
commit 0e8a4f77bc
2 changed files with 7 additions and 3 deletions

View File

@ -21,6 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -60,6 +61,7 @@ public class NumpyArray {
setND4JArray(); setND4JArray();
if (copy){ if (copy){
nd4jArray = nd4jArray.dup(); nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address(); this.address = nd4jArray.data().address();
} }
@ -85,6 +87,7 @@ public class NumpyArray {
setND4JArray(); setND4JArray();
if (copy){ if (copy){
nd4jArray = nd4jArray.dup(); nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address(); this.address = nd4jArray.data().address();
} }
} }
@ -104,11 +107,12 @@ public class NumpyArray {
nd4jStrides[i] = strides[i] / elemSize; nd4jStrides[i] = strides[i] / elemSize;
} }
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
} }
public NumpyArray(INDArray nd4jArray){ public NumpyArray(INDArray nd4jArray){
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
DataBuffer buff = nd4jArray.data(); DataBuffer buff = nd4jArray.data();
address = buff.pointer().address(); address = buff.pointer().address();
shape = nd4jArray.shape(); shape = nd4jArray.shape();

View File

@ -605,7 +605,7 @@ public class PythonExecutioner {
private static synchronized void _exec(String code) { private static synchronized void _exec(String code) {
log.info(code); log.debug(code);
log.info("CPython: PyRun_SimpleStringFlag()"); log.info("CPython: PyRun_SimpleStringFlag()");
int result = PyRun_SimpleStringFlags(code, null); int result = PyRun_SimpleStringFlags(code, null);