Fix an issue when creating DataBuffer/INDArray from ByteBuffer for multiple datatypes (#446)

* Fix missing dtypes when creating DataBuffer from ByteBuffer

Signed-off-by: Alex Black <blacka101@gmail.com>

* Revert LongIndexer -> ULongIndexer; fixes for UIntIndexer

Signed-off-by: Alex Black <blacka101@gmail.com>

* CUDA fix

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-11 21:29:52 +10:00 committed by GitHub
parent 91a2004d8f
commit b786418c5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 11 deletions

View File

@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case FLOAT:
return ((FloatIndexer) indexer).get(i);
case UINT32:
return ((UIntIndexer) indexer).get(i);
case INT:
return ((IntIndexer) indexer).get(i);
case BFLOAT16:
@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (long) ((Bfloat16Indexer) indexer).get(i);
case HALF:
return (long) ((HalfIndexer) indexer).get( i);
case UINT64:
case UINT64: //Fall through
case LONG:
return ((LongIndexer) indexer).get(i);
case UINT32:
return (long) ((UIntIndexer) indexer).get(i);
case INT:
return (long) ((IntIndexer) indexer).get(i);
case UINT16:
@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
case UINT32:
return (short) ((UIntIndexer)indexer).get(i);
case INT:
return (short) ((IntIndexer) indexer).get(i);
case UINT16:
@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
case UINT32:
return (float) ((UIntIndexer)indexer).get(i);
case INT:
return (float) ((IntIndexer) indexer).get(i);
case UINT16:
@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return (float) ((UByteIndexer) indexer).get(i);
case BYTE:
return (float) ((ByteIndexer) indexer).get(i);
case UINT64:
case UINT64: //Fall through
case LONG:
return (float) ((LongIndexer) indexer).get(i);
case FLOAT:
@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
case BOOL:
return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
case UINT32:
return (int)((UIntIndexer) indexer).get(i);
case INT:
return ((IntIndexer) indexer).get(i);
case BFLOAT16:
@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
return ((UByteIndexer) indexer).get(i);
case BYTE:
return ((ByteIndexer) indexer).get(i);
case UINT64:
case UINT64: //Fall through
case LONG:
return (int) ((LongIndexer) indexer).get(i);
case FLOAT:
@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;
@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, (long)element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;
@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT:
((IntIndexer) indexer).put(i, element);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
((LongIndexer) indexer).put(i, element);
break;
@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
case SHORT:
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
break;
case INT:
case UINT32:
((UIntIndexer) indexer).put(i, element ? 1 : 0);
break;
case INT:
((IntIndexer) indexer).put(i, element ? 1 : 0);
break;
case UINT64:
@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
((ShortIndexer) indexer).put(i, (short) element);
break;
case UINT32:
((UIntIndexer) indexer).put(i, element);
break;
case INT:
((IntIndexer) indexer).put(i, (int) element);
break;

View File

@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer);
break;
case UINT32:
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);
@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = FloatIndexer.create((FloatPointer) pointer);
break;
case UINT32:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
indexer = HalfIndexer.create((ShortPointer) pointer);
break;
case UINT64:
case UINT64: //Fall through
case LONG:
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);

View File

@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
setIndexer(ByteIndexer.create((BytePointer) pointer));
} else if(dataType() == DataType.FLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(HalfIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BFLOAT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.BOOL){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
} else if(dataType() == DataType.UINT16){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
setIndexer(UShortIndexer.create((ShortPointer) pointer));
} else if(dataType() == DataType.UINT32){
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
setIndexer(LongIndexer.create((LongPointer) pointer));
}
Nd4j.getDeallocatorService().pickObject(this);
@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
} else if (dataType() == DataType.UINT32) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
// FIXME: we need unsigned indexer here
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
if (initialize)
fillPointerWithZero();
} else if (dataType() == DataType.UINT64) {
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
// FIXME: we need unsigned indexer here
setIndexer(LongIndexer.create((LongPointer) pointer));
if (initialize)
@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
// FIXME: need unsigned indexer here
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (dataType() == DataType.UINT64) {
attached = true;

View File

@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(e, z);
}
@Test
public void testCreateBufferFromByteBuffer(){
for(DataType dt : DataType.values()){
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
continue;
// System.out.println(dt);
int lengthBytes = 256;
int lengthElements = lengthBytes / dt.width();
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
arr.toStringFull();
}
}
@Override
public char ordering() {
return 'c';