Fix an issue when creating DataBuffer/INDArray from ByteBuffer for multiple datatypes (#446)
* Fix missing dtypes when creating DataBuffer from ByteBuffer Signed-off-by: Alex Black <blacka101@gmail.com> * Revert LongIndexer -> ULongIndexer; fixes for UIntIndexer Signed-off-by: Alex Black <blacka101@gmail.com> * CUDA fix Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
91a2004d8f
commit
b786418c5d
|
@ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return ((FloatIndexer) indexer).get(i);
|
return ((FloatIndexer) indexer).get(i);
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
return ((UIntIndexer) indexer).get(i);
|
||||||
case INT:
|
case INT:
|
||||||
return ((IntIndexer) indexer).get(i);
|
return ((IntIndexer) indexer).get(i);
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
|
@ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
return (long) ((Bfloat16Indexer) indexer).get(i);
|
return (long) ((Bfloat16Indexer) indexer).get(i);
|
||||||
case HALF:
|
case HALF:
|
||||||
return (long) ((HalfIndexer) indexer).get( i);
|
return (long) ((HalfIndexer) indexer).get( i);
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
return ((LongIndexer) indexer).get(i);
|
return ((LongIndexer) indexer).get(i);
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
return (long) ((UIntIndexer) indexer).get(i);
|
||||||
case INT:
|
case INT:
|
||||||
return (long) ((IntIndexer) indexer).get(i);
|
return (long) ((IntIndexer) indexer).get(i);
|
||||||
case UINT16:
|
case UINT16:
|
||||||
|
@ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
|
return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0);
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
return (short) ((UIntIndexer)indexer).get(i);
|
||||||
case INT:
|
case INT:
|
||||||
return (short) ((IntIndexer) indexer).get(i);
|
return (short) ((IntIndexer) indexer).get(i);
|
||||||
case UINT16:
|
case UINT16:
|
||||||
|
@ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
|
return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
return (float) ((UIntIndexer)indexer).get(i);
|
||||||
case INT:
|
case INT:
|
||||||
return (float) ((IntIndexer) indexer).get(i);
|
return (float) ((IntIndexer) indexer).get(i);
|
||||||
case UINT16:
|
case UINT16:
|
||||||
|
@ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
return (float) ((UByteIndexer) indexer).get(i);
|
return (float) ((UByteIndexer) indexer).get(i);
|
||||||
case BYTE:
|
case BYTE:
|
||||||
return (float) ((ByteIndexer) indexer).get(i);
|
return (float) ((ByteIndexer) indexer).get(i);
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
return (float) ((LongIndexer) indexer).get(i);
|
return (float) ((LongIndexer) indexer).get(i);
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
|
@ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
|
return ((BooleanIndexer) indexer).get(i) ? 1 : 0;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
return (int)((UIntIndexer) indexer).get(i);
|
||||||
case INT:
|
case INT:
|
||||||
return ((IntIndexer) indexer).get(i);
|
return ((IntIndexer) indexer).get(i);
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
|
@ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
return ((UByteIndexer) indexer).get(i);
|
return ((UByteIndexer) indexer).get(i);
|
||||||
case BYTE:
|
case BYTE:
|
||||||
return ((ByteIndexer) indexer).get(i);
|
return ((ByteIndexer) indexer).get(i);
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
return (int) ((LongIndexer) indexer).get(i);
|
return (int) ((LongIndexer) indexer).get(i);
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
|
@ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
((ShortIndexer) indexer).put(i, (short) element);
|
((ShortIndexer) indexer).put(i, (short) element);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
((UIntIndexer) indexer).put(i, (long)element);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
((IntIndexer) indexer).put(i, (int) element);
|
((IntIndexer) indexer).put(i, (int) element);
|
||||||
break;
|
break;
|
||||||
|
@ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
((ShortIndexer) indexer).put(i, (short) element);
|
((ShortIndexer) indexer).put(i, (short) element);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
((UIntIndexer) indexer).put(i, (long)element);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
((IntIndexer) indexer).put(i, (int) element);
|
((IntIndexer) indexer).put(i, (int) element);
|
||||||
break;
|
break;
|
||||||
|
@ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
((ShortIndexer) indexer).put(i, (short) element);
|
((ShortIndexer) indexer).put(i, (short) element);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
((UIntIndexer) indexer).put(i, element);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
((IntIndexer) indexer).put(i, element);
|
((IntIndexer) indexer).put(i, element);
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
((LongIndexer) indexer).put(i, element);
|
((LongIndexer) indexer).put(i, element);
|
||||||
break;
|
break;
|
||||||
|
@ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
case SHORT:
|
case SHORT:
|
||||||
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
|
((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0);
|
||||||
break;
|
break;
|
||||||
case INT:
|
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
((UIntIndexer) indexer).put(i, element ? 1 : 0);
|
||||||
|
break;
|
||||||
|
case INT:
|
||||||
((IntIndexer) indexer).put(i, element ? 1 : 0);
|
((IntIndexer) indexer).put(i, element ? 1 : 0);
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64:
|
||||||
|
@ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
((ShortIndexer) indexer).put(i, (short) element);
|
((ShortIndexer) indexer).put(i, (short) element);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
((UIntIndexer) indexer).put(i, element);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
((IntIndexer) indexer).put(i, (int) element);
|
((IntIndexer) indexer).put(i, (int) element);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
indexer = FloatIndexer.create((FloatPointer) pointer);
|
indexer = FloatIndexer.create((FloatPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
|
this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
@ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
|
this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer();
|
||||||
indexer = HalfIndexer.create((ShortPointer) pointer);
|
indexer = HalfIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
|
this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer();
|
||||||
indexer = LongIndexer.create((LongPointer) pointer);
|
indexer = LongIndexer.create((LongPointer) pointer);
|
||||||
|
@ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
indexer = FloatIndexer.create((FloatPointer) pointer);
|
indexer = FloatIndexer.create((FloatPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
|
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
@ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
|
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer();
|
||||||
indexer = HalfIndexer.create((ShortPointer) pointer);
|
indexer = HalfIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64: //Fall through
|
||||||
case LONG:
|
case LONG:
|
||||||
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
|
this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer();
|
||||||
indexer = LongIndexer.create((LongPointer) pointer);
|
indexer = LongIndexer.create((LongPointer) pointer);
|
||||||
|
|
|
@ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer();
|
||||||
|
|
||||||
setIndexer(ByteIndexer.create((BytePointer) pointer));
|
setIndexer(ByteIndexer.create((BytePointer) pointer));
|
||||||
|
} else if(dataType() == DataType.FLOAT16){
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||||
|
setIndexer(HalfIndexer.create((ShortPointer) pointer));
|
||||||
|
} else if(dataType() == DataType.BFLOAT16){
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||||
|
setIndexer(Bfloat16Indexer.create((ShortPointer) pointer));
|
||||||
|
} else if(dataType() == DataType.BOOL){
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer();
|
||||||
|
setIndexer(BooleanIndexer.create((BooleanPointer) pointer));
|
||||||
|
} else if(dataType() == DataType.UINT16){
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer();
|
||||||
|
setIndexer(UShortIndexer.create((ShortPointer) pointer));
|
||||||
|
} else if(dataType() == DataType.UINT32){
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
|
||||||
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
|
} else if (dataType() == DataType.UINT64) {
|
||||||
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
|
||||||
|
setIndexer(LongIndexer.create((LongPointer) pointer));
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4j.getDeallocatorService().pickObject(this);
|
Nd4j.getDeallocatorService().pickObject(this);
|
||||||
|
@ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
} else if (dataType() == DataType.UINT32) {
|
} else if (dataType() == DataType.UINT32) {
|
||||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer();
|
||||||
|
|
||||||
// FIXME: we need unsigned indexer here
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
|
||||||
|
|
||||||
if (initialize)
|
if (initialize)
|
||||||
fillPointerWithZero();
|
fillPointerWithZero();
|
||||||
} else if (dataType() == DataType.UINT64) {
|
} else if (dataType() == DataType.UINT64) {
|
||||||
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
|
pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer();
|
||||||
|
|
||||||
// FIXME: we need unsigned indexer here
|
|
||||||
setIndexer(LongIndexer.create((LongPointer) pointer));
|
setIndexer(LongIndexer.create((LongPointer) pointer));
|
||||||
|
|
||||||
if (initialize)
|
if (initialize)
|
||||||
|
@ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
|
|
||||||
// FIXME: need unsigned indexer here
|
// FIXME: need unsigned indexer here
|
||||||
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
|
|
||||||
} else if (dataType() == DataType.UINT64) {
|
} else if (dataType() == DataType.UINT64) {
|
||||||
attached = true;
|
attached = true;
|
||||||
|
|
|
@ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(e, z);
|
assertEquals(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateBufferFromByteBuffer(){
|
||||||
|
|
||||||
|
for(DataType dt : DataType.values()){
|
||||||
|
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||||
|
continue;
|
||||||
|
// System.out.println(dt);
|
||||||
|
|
||||||
|
int lengthBytes = 256;
|
||||||
|
int lengthElements = lengthBytes / dt.width();
|
||||||
|
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
|
||||||
|
|
||||||
|
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
|
||||||
|
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
|
||||||
|
|
||||||
|
arr.toStringFull();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue