Fix an issue when creating DataBuffer/INDArray from ByteBuffer for multiple datatypes (#446)

* Fix missing dtypes when creating DataBuffer from ByteBuffer

Signed-off-by: Alex Black <blacka101@gmail.com>

* Revert LongIndexer -> ULongIndexer; fixes for UIntIndexer

Signed-off-by: Alex Black <blacka101@gmail.com>

* CUDA fix

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-11 21:29:52 +10:00 committed by GitHub
parent 91a2004d8f
commit b786418c5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 11 deletions

View File

@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case FLOAT: case FLOAT:
return ((FloatIndexer) indexer).get(i); return ((FloatIndexer) indexer).get(i);
case UINT32: case UINT32:
return ((UIntIndexer) indexer).get(i);
case INT: case INT:
return ((IntIndexer) indexer).get(i); return ((IntIndexer) indexer).get(i);
case BFLOAT16: case BFLOAT16:
@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (long) ((Bfloat16Indexer) indexer).get(i); return (long) ((Bfloat16Indexer) indexer).get(i);
case HALF: case HALF:
return (long) ((HalfIndexer) indexer).get( i); return (long) ((HalfIndexer) indexer).get( i);
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
return ((LongIndexer) indexer).get(i); return ((LongIndexer) indexer).get(i);
case UINT32: case UINT32:
return (long) ((UIntIndexer) indexer).get(i);
case INT: case INT:
return (long) ((IntIndexer) indexer).get(i); return (long) ((IntIndexer) indexer).get(i);
case UINT16: case UINT16:
@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL: case BOOL:
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0); return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
case UINT32: case UINT32:
return (short) ((UIntIndexer)indexer).get(i);
case INT: case INT:
return (short) ((IntIndexer) indexer).get(i); return (short) ((IntIndexer) indexer).get(i);
case UINT16: case UINT16:
@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL: case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f; return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
case UINT32: case UINT32:
return (float) ((UIntIndexer)indexer).get(i);
case INT: case INT:
return (float) ((IntIndexer) indexer).get(i); return (float) ((IntIndexer) indexer).get(i);
case UINT16: case UINT16:
@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (float) ((UByteIndexer) indexer).get(i); return (float) ((UByteIndexer) indexer).get(i);
case BYTE: case BYTE:
return (float) ((ByteIndexer) indexer).get(i); return (float) ((ByteIndexer) indexer).get(i);
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
return (float) ((LongIndexer) indexer).get(i); return (float) ((LongIndexer) indexer).get(i);
case FLOAT: case FLOAT:
@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL: case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1 : 0; return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
case UINT32: case UINT32:
return (int)((UIntIndexer) indexer).get(i);
case INT: case INT:
return ((IntIndexer) indexer).get(i); return ((IntIndexer) indexer).get(i);
case BFLOAT16: case BFLOAT16:
@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return ((UByteIndexer) indexer).get(i); return ((UByteIndexer) indexer).get(i);
case BYTE: case BYTE:
return ((ByteIndexer) indexer).get(i); return ((ByteIndexer) indexer).get(i);
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
return (int) ((LongIndexer) indexer).get(i); return (int) ((LongIndexer) indexer).get(i);
case FLOAT: case FLOAT:
@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element); ((ShortIndexer) indexer).put(i, (short) element);
break; break;
case UINT32: case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT: case INT:
((IntIndexer) indexer).put(i, (int) element); ((IntIndexer) indexer).put(i, (int) element);
break; break;
@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element); ((ShortIndexer) indexer).put(i, (short) element);
break; break;
case UINT32: case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT: case INT:
((IntIndexer) indexer).put(i, (int) element); ((IntIndexer) indexer).put(i, (int) element);
break; break;
@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element); ((ShortIndexer) indexer).put(i, (short) element);
break; break;
case UINT32: case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT: case INT:
((IntIndexer) indexer).put(i, element); ((IntIndexer) indexer).put(i, element);
break; break;
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
((LongIndexer) indexer).put(i, element); ((LongIndexer) indexer).put(i, element);
break; break;
@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
case SHORT: case SHORT:
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0); ((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
break; break;
case INT:
case UINT32: case UINT32:
((UIntIndexer) indexer).put(i, element ? 1 : 0);
break;
case INT:
((IntIndexer) indexer).put(i, element ? 1 : 0); ((IntIndexer) indexer).put(i, element ? 1 : 0);
break; break;
case UINT64: case UINT64:
@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element); ((ShortIndexer) indexer).put(i, (short) element);
break; break;
case UINT32: case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT: case INT:
((IntIndexer) indexer).put(i, (int) element); ((IntIndexer) indexer).put(i, (int) element);
break; break;

View File

@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer); indexer = FloatIndexer.create((FloatPointer) pointer);
break; break;
case UINT32: case UINT32:
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT: case INT:
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer); indexer = IntIndexer.create((IntPointer) pointer);
@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer); indexer = HalfIndexer.create((ShortPointer) pointer);
break; break;
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer(); this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer); indexer = LongIndexer.create((LongPointer) pointer);
@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer); indexer = FloatIndexer.create((FloatPointer) pointer);
break; break;
case UINT32: case UINT32:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT: case INT:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer); indexer = IntIndexer.create((IntPointer) pointer);
@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer); indexer = HalfIndexer.create((ShortPointer) pointer);
break; break;
case UINT64: case UINT64: //Fall through
case LONG: case LONG:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer(); this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer); indexer = LongIndexer.create((LongPointer) pointer);

View File

@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
setIndexer(ByteIndexer.create((BytePointer) pointer)); setIndexer(ByteIndexer.create((BytePointer) pointer));
} else if(dataType() == DataType.FLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(HalfIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BFLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BOOL){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
} else if(dataType() == DataType.UINT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(UShortIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.UINT32){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
setIndexer(LongIndexer.create((LongPointer) pointer));
} }
Nd4j.getDeallocatorService().pickObject(this); Nd4j.getDeallocatorService().pickObject(this);
@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
} else if (dataType() == DataType.UINT32) { } else if (dataType() == DataType.UINT32) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
// FIXME: we need unsigned indexer here setIndexer(UIntIndexer.create((IntPointer) pointer));
setIndexer(IntIndexer.create((IntPointer) pointer));
if (initialize) if (initialize)
fillPointerWithZero(); fillPointerWithZero();
} else if (dataType() == DataType.UINT64) { } else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
// FIXME: we need unsigned indexer here
setIndexer(LongIndexer.create((LongPointer) pointer)); setIndexer(LongIndexer.create((LongPointer) pointer));
if (initialize) if (initialize)
@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
// FIXME: need unsigned indexer here // FIXME: need unsigned indexer here
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
setIndexer(IntIndexer.create((IntPointer) pointer)); setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) { } else if (dataType() == DataType.UINT64) {
attached = true; attached = true;

View File

@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(e, z); assertEquals(e, z);
} }
@Test
public void testCreateBufferFromByteBuffer(){
for(DataType dt : DataType.values()){
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
continue;
// System.out.println(dt);
int lengthBytes = 256;
int lengthElements = lengthBytes / dt.width();
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
arr.toStringFull();
}
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';