- numpy import fix for CUDA (#64)
- skip tagLocation for empty arrays Signed-off-by: raver119 <raver119@gmail.com>master
parent
c9e867b2e8
commit
c499dc962f
|
@ -529,7 +529,7 @@ public class AtomicAllocator implements Allocator {
|
|||
* @param objectId
|
||||
* @return
|
||||
*/
|
||||
protected AllocationPoint getAllocationPoint(Long objectId) {
|
||||
protected AllocationPoint getAllocationPoint(@NonNull Long objectId) {
|
||||
return allocationsMap.get(objectId);
|
||||
}
|
||||
|
||||
|
|
|
@ -339,6 +339,10 @@ public class CudaAffinityManager extends BasicAffinityManager {
|
|||
*/
|
||||
@Override
|
||||
public void tagLocation(INDArray array, Location location) {
|
||||
// we can't tag empty arrays.
|
||||
if (array.isEmpty())
|
||||
return;
|
||||
|
||||
if (location == Location.HOST)
|
||||
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
|
||||
else if (location == Location.DEVICE)
|
||||
|
|
|
@ -116,6 +116,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
|
||||
//cuda specific bits
|
||||
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
|
||||
this.trackingPoint = allocationPoint.getObjectId();
|
||||
|
||||
Nd4j.getDeallocatorService().pickObject(this);
|
||||
|
||||
|
@ -124,41 +125,20 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
|
||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||
|
||||
if (allocationPoint.getHostPointer() != null) {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
||||
} else {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
||||
}
|
||||
|
||||
context.getSpecialStream().synchronize();
|
||||
|
||||
if (allocationPoint.getHostPointer() != null)
|
||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
|
||||
|
||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
||||
|
||||
this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length * getElementSize(), 0);
|
||||
|
||||
switch (dataType()) {
|
||||
case INT: {
|
||||
setIndexer(IntIndexer.create(((CudaPointer) this.pointer).asIntPointer()));
|
||||
}
|
||||
break;
|
||||
case FLOAT: {
|
||||
setIndexer(FloatIndexer.create(((CudaPointer) this.pointer).asFloatPointer()));
|
||||
}
|
||||
break;
|
||||
case DOUBLE: {
|
||||
setIndexer(DoubleIndexer.create(((CudaPointer) this.pointer).asDoublePointer()));
|
||||
}
|
||||
break;
|
||||
case HALF: {
|
||||
setIndexer(ShortIndexer.create(((CudaPointer) this.pointer).asShortPointer()));
|
||||
}
|
||||
break;
|
||||
case LONG: {
|
||||
setIndexer(LongIndexer.create(((CudaPointer) this.pointer).asLongPointer()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
this.trackingPoint = allocationPoint.getObjectId();
|
||||
|
||||
}
|
||||
|
||||
public BaseCudaDataBuffer(float[] data, boolean copy) {
|
||||
|
|
|
@ -310,6 +310,13 @@ public class NumpyFormatTests extends BaseNd4jTest {
|
|||
INDArray act1 = Nd4j.createFromNpyFile(f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAbsentNumpyFile_2() throws Exception {
|
||||
val f = new File("c:/develop/batch-x-1.npy");
|
||||
INDArray act1 = Nd4j.createFromNpyFile(f);
|
||||
log.info("Array shape: {}; sum: {};", act1.shape(), act1.sumNumber().doubleValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue