- numpy import fix for CUDA (#64)

- skip tagLocation for empty arrays

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-17 15:19:38 +03:00 committed by AlexDBlack
parent c9e867b2e8
commit c499dc962f
4 changed files with 22 additions and 31 deletions

View File

@ -529,7 +529,7 @@ public class AtomicAllocator implements Allocator {
* @param objectId
* @return
*/
protected AllocationPoint getAllocationPoint(Long objectId) {
protected AllocationPoint getAllocationPoint(@NonNull Long objectId) {
// Direct lookup in allocationsMap keyed by objectId; the result of get() is returned as-is.
// NOTE(review): presumably yields null for an unknown id (typical Map contract) - confirm against allocationsMap's type.
return allocationsMap.get(objectId);
}

View File

@ -339,6 +339,10 @@ public class CudaAffinityManager extends BasicAffinityManager {
*/
@Override
public void tagLocation(INDArray array, Location location) {
// we can't tag empty arrays.
if (array.isEmpty())
return;
if (location == Location.HOST)
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
else if (location == Location.DEVICE)

View File

@ -116,6 +116,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
//cuda specific bits
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
this.trackingPoint = allocationPoint.getObjectId();
Nd4j.getDeallocatorService().pickObject(this);
@ -124,41 +125,20 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
if (allocationPoint.getHostPointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
} else {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
}
context.getSpecialStream().synchronize();
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
if (allocationPoint.getHostPointer() != null)
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length * getElementSize(), 0);
switch (dataType()) {
case INT: {
setIndexer(IntIndexer.create(((CudaPointer) this.pointer).asIntPointer()));
}
break;
case FLOAT: {
setIndexer(FloatIndexer.create(((CudaPointer) this.pointer).asFloatPointer()));
}
break;
case DOUBLE: {
setIndexer(DoubleIndexer.create(((CudaPointer) this.pointer).asDoublePointer()));
}
break;
case HALF: {
setIndexer(ShortIndexer.create(((CudaPointer) this.pointer).asShortPointer()));
}
break;
case LONG: {
setIndexer(LongIndexer.create(((CudaPointer) this.pointer).asLongPointer()));
}
break;
}
this.trackingPoint = allocationPoint.getObjectId();
}
public BaseCudaDataBuffer(float[] data, boolean copy) {

View File

@ -310,6 +310,13 @@ public class NumpyFormatTests extends BaseNd4jTest {
INDArray act1 = Nd4j.createFromNpyFile(f);
}
/**
 * Loads a numpy file from a developer-local path and logs its shape and element sum.
 *
 * The path is machine-specific, so the test is skipped (JUnit assumption) instead of
 * failing when the file is not present on the machine running the suite.
 */
@Test
public void testAbsentNumpyFile_2() throws Exception {
    val f = new File("c:/develop/batch-x-1.npy");
    // Without this guard the test errors out on every machine that lacks the local file.
    org.junit.Assume.assumeTrue(f.exists());
    INDArray loaded = Nd4j.createFromNpyFile(f);
    log.info("Array shape: {}; sum: {};", loaded.shape(), loaded.sumNumber().doubleValue());
}
@Override
public char ordering() {
return 'c';