Fix an issue when creating DataBuffer/INDArray from ByteBuffer for multiple datatypes (#446)
* Fix missing dtypes when creating DataBuffer from ByteBuffer Signed-off-by: Alex Black <blacka101@gmail.com> * Revert LongIndexer -> ULongIndexer; fixes for UIntIndexer Signed-off-by: Alex Black <blacka101@gmail.com> * CUDA fix Signed-off-by: Alex Black <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									91a2004d8f
								
							
						
					
					
						commit
						b786418c5d
					
				| @ -826,6 +826,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|             case FLOAT: | ||||
|                 return ((FloatIndexer) indexer).get(i); | ||||
|             case UINT32: | ||||
|                 return ((UIntIndexer) indexer).get(i); | ||||
|             case INT: | ||||
|                 return ((IntIndexer) indexer).get(i); | ||||
|             case BFLOAT16: | ||||
| @ -866,10 +867,11 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 return (long) ((Bfloat16Indexer) indexer).get(i); | ||||
|             case HALF: | ||||
|                 return (long) ((HalfIndexer) indexer).get( i); | ||||
|             case UINT64: | ||||
|             case UINT64:    //Fall through | ||||
|             case LONG: | ||||
|                 return ((LongIndexer) indexer).get(i); | ||||
|             case UINT32: | ||||
|                 return (long) ((UIntIndexer) indexer).get(i); | ||||
|             case INT: | ||||
|                 return (long) ((IntIndexer) indexer).get(i); | ||||
|             case UINT16: | ||||
| @ -906,6 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|             case BOOL: | ||||
|                 return (short) (((BooleanIndexer) indexer).get(i) ? 1 : 0); | ||||
|             case UINT32: | ||||
|                 return (short) ((UIntIndexer)indexer).get(i); | ||||
|             case INT: | ||||
|                 return (short) ((IntIndexer) indexer).get(i); | ||||
|             case UINT16: | ||||
| @ -943,6 +946,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|             case BOOL: | ||||
|                 return ((BooleanIndexer) indexer).get(i) ? 1.f : 0.f; | ||||
|             case UINT32: | ||||
|                 return (float) ((UIntIndexer)indexer).get(i); | ||||
|             case INT: | ||||
|                 return (float) ((IntIndexer) indexer).get(i); | ||||
|             case UINT16: | ||||
| @ -957,7 +961,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 return (float) ((UByteIndexer) indexer).get(i); | ||||
|             case BYTE: | ||||
|                 return (float) ((ByteIndexer) indexer).get(i); | ||||
|             case UINT64: | ||||
|             case UINT64:  //Fall through | ||||
|             case LONG: | ||||
|                 return (float)  ((LongIndexer) indexer).get(i); | ||||
|             case FLOAT: | ||||
| @ -978,6 +982,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|             case BOOL: | ||||
|                 return ((BooleanIndexer) indexer).get(i) ? 1 : 0; | ||||
|             case UINT32: | ||||
|                 return (int)((UIntIndexer) indexer).get(i); | ||||
|             case INT: | ||||
|                 return ((IntIndexer) indexer).get(i); | ||||
|             case BFLOAT16: | ||||
| @ -992,7 +997,7 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 return ((UByteIndexer) indexer).get(i); | ||||
|             case BYTE: | ||||
|                 return ((ByteIndexer) indexer).get(i); | ||||
|             case UINT64: | ||||
|             case UINT64:  //Fall through | ||||
|             case LONG: | ||||
|                 return (int) ((LongIndexer) indexer).get(i); | ||||
|             case FLOAT: | ||||
| @ -1058,6 +1063,8 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 ((ShortIndexer) indexer).put(i,  (short) element); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 ((UIntIndexer) indexer).put(i, (long)element); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 ((IntIndexer) indexer).put(i, (int) element); | ||||
|                 break; | ||||
| @ -1104,6 +1111,8 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 ((ShortIndexer) indexer).put(i,  (short) element); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 ((UIntIndexer) indexer).put(i, (long)element); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 ((IntIndexer) indexer).put(i, (int) element); | ||||
|                 break; | ||||
| @ -1150,10 +1159,12 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 ((ShortIndexer) indexer).put(i,  (short) element); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 ((UIntIndexer) indexer).put(i, element); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 ((IntIndexer) indexer).put(i, element); | ||||
|                 break; | ||||
|             case UINT64: | ||||
|             case UINT64: //Fall through | ||||
|             case LONG: | ||||
|                 ((LongIndexer) indexer).put(i, element); | ||||
|                 break; | ||||
| @ -1195,8 +1206,10 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|             case SHORT: | ||||
|                 ((ShortIndexer) indexer).put(i, element ? (short) 1 : (short) 0); | ||||
|                 break; | ||||
|             case INT: | ||||
|             case UINT32: | ||||
|                 ((UIntIndexer) indexer).put(i, element ? 1 : 0); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 ((IntIndexer) indexer).put(i, element ? 1 : 0); | ||||
|                 break; | ||||
|             case UINT64: | ||||
| @ -1242,6 +1255,8 @@ public abstract class BaseDataBuffer implements DataBuffer { | ||||
|                 ((ShortIndexer) indexer).put(i, (short) element); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 ((UIntIndexer) indexer).put(i, element); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 ((IntIndexer) indexer).put(i, (int) element); | ||||
|                 break; | ||||
|  | ||||
| @ -324,6 +324,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|                 indexer = FloatIndexer.create((FloatPointer) pointer); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); | ||||
|                 indexer = UIntIndexer.create((IntPointer) pointer); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 this.pointer = new CudaPointer(hostPointer, length, 0).asIntPointer(); | ||||
|                 indexer = IntIndexer.create((IntPointer) pointer); | ||||
| @ -336,7 +339,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|                 this.pointer = new CudaPointer(hostPointer, length, 0).asShortPointer(); | ||||
|                 indexer = HalfIndexer.create((ShortPointer) pointer); | ||||
|                 break; | ||||
|             case UINT64: | ||||
|             case UINT64:    //Fall through | ||||
|             case LONG: | ||||
|                 this.pointer = new CudaPointer(hostPointer, length, 0).asLongPointer(); | ||||
|                 indexer = LongIndexer.create((LongPointer) pointer); | ||||
| @ -501,6 +504,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|                 indexer = FloatIndexer.create((FloatPointer) pointer); | ||||
|                 break; | ||||
|             case UINT32: | ||||
|                 this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); | ||||
|                 indexer = UIntIndexer.create((IntPointer) pointer); | ||||
|                 break; | ||||
|             case INT: | ||||
|                 this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asIntPointer(); | ||||
|                 indexer = IntIndexer.create((IntPointer) pointer); | ||||
| @ -513,7 +519,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|                 this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asShortPointer(); | ||||
|                 indexer = HalfIndexer.create((ShortPointer) pointer); | ||||
|                 break; | ||||
|             case UINT64: | ||||
|             case UINT64: //Fall through | ||||
|             case LONG: | ||||
|                 this.pointer = new CudaPointer(hostPointer, originalBuffer.length()).asLongPointer(); | ||||
|                 indexer = LongIndexer.create((LongPointer) pointer); | ||||
|  | ||||
| @ -121,6 +121,24 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBytePointer(); | ||||
| 
 | ||||
|             setIndexer(ByteIndexer.create((BytePointer) pointer)); | ||||
|         } else if(dataType() == DataType.FLOAT16){ | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); | ||||
|             setIndexer(HalfIndexer.create((ShortPointer) pointer)); | ||||
|         } else if(dataType() == DataType.BFLOAT16){ | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); | ||||
|             setIndexer(Bfloat16Indexer.create((ShortPointer) pointer)); | ||||
|         } else if(dataType() == DataType.BOOL){ | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asBoolPointer(); | ||||
|             setIndexer(BooleanIndexer.create((BooleanPointer) pointer)); | ||||
|         } else if(dataType() == DataType.UINT16){ | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asShortPointer(); | ||||
|             setIndexer(UShortIndexer.create((ShortPointer) pointer)); | ||||
|         } else if(dataType() == DataType.UINT32){ | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); | ||||
|             setIndexer(UIntIndexer.create((IntPointer) pointer)); | ||||
|         } else if (dataType() == DataType.UINT64) { | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); | ||||
|             setIndexer(LongIndexer.create((LongPointer) pointer)); | ||||
|         } | ||||
| 
 | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| @ -336,15 +354,13 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
|         } else if (dataType() == DataType.UINT32) { | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asIntPointer(); | ||||
| 
 | ||||
|             // FIXME: we need unsigned indexer here | ||||
|             setIndexer(IntIndexer.create((IntPointer) pointer)); | ||||
|             setIndexer(UIntIndexer.create((IntPointer) pointer)); | ||||
| 
 | ||||
|             if (initialize) | ||||
|                 fillPointerWithZero(); | ||||
|         } else if (dataType() == DataType.UINT64) { | ||||
|             pointer = new PagedPointer(ptrDataBuffer.primaryBuffer(), length).asLongPointer(); | ||||
| 
 | ||||
|             // FIXME: we need unsigned indexer here | ||||
|             setIndexer(LongIndexer.create((LongPointer) pointer)); | ||||
| 
 | ||||
|             if (initialize) | ||||
| @ -500,7 +516,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
| 
 | ||||
|             // FIXME: need unsigned indexer here | ||||
|             pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); | ||||
|             setIndexer(IntIndexer.create((IntPointer) pointer)); | ||||
|             setIndexer(UIntIndexer.create((IntPointer) pointer)); | ||||
| 
 | ||||
|         } else if (dataType() == DataType.UINT64) { | ||||
|             attached = true; | ||||
|  | ||||
| @ -8395,6 +8395,25 @@ public class Nd4jTestsC extends BaseNd4jTest { | ||||
|         assertEquals(e, z); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCreateBufferFromByteBuffer(){ | ||||
| 
 | ||||
|         for(DataType dt : DataType.values()){ | ||||
|             if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) | ||||
|                 continue; | ||||
| //            System.out.println(dt); | ||||
| 
 | ||||
|             int lengthBytes = 256; | ||||
|             int lengthElements = lengthBytes / dt.width(); | ||||
|             ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes); | ||||
| 
 | ||||
|             DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0); | ||||
|             INDArray arr = Nd4j.create(db, new long[]{lengthElements}); | ||||
| 
 | ||||
|             arr.toStringFull(); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public char ordering() { | ||||
|         return 'c'; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user