Fix wrong indexer for some DataBuffer constructors for UINT32 datatype (#458)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
c396fcb960
commit
a4d74ec4d0
|
@ -1630,7 +1630,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
||||
} else if (t == DataType.UINT32) {
|
||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
} else if (t == DataType.INT) {
|
||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
|
@ -1699,6 +1699,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
@ -1750,6 +1753,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
|
|
@ -411,7 +411,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
||||
} else if (t == DataType.UINT32) {
|
||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
} else if (t == DataType.INT) {
|
||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||
|
@ -514,7 +514,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
attached = true;
|
||||
parentWorkspace = workspace;
|
||||
|
||||
// FIXME: need unsigned indexer here
|
||||
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||
|
||||
|
@ -882,6 +881,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
@ -932,6 +934,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||
break;
|
||||
case UINT32:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||
break;
|
||||
case INT:
|
||||
pointer = nPtr.asIntPointer();
|
||||
indexer = IntIndexer.create((IntPointer) pointer);
|
||||
|
|
|
@ -8411,6 +8411,56 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
|
||||
|
||||
arr.toStringFull();
|
||||
arr.toString();
|
||||
|
||||
for(DataType dt2 : DataType.values()) {
|
||||
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
|
||||
continue;
|
||||
INDArray a2 = arr.castTo(dt2);
|
||||
a2.toStringFull();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateBufferFromByteBufferViews(){
|
||||
|
||||
for(DataType dt : DataType.values()){
|
||||
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||
continue;
|
||||
// System.out.println(dt);
|
||||
|
||||
int lengthBytes = 256;
|
||||
int lengthElements = lengthBytes / dt.width();
|
||||
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
|
||||
|
||||
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
|
||||
INDArray arr = Nd4j.create(db, new long[]{lengthElements/2, 2});
|
||||
|
||||
arr.toStringFull();
|
||||
|
||||
INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||
INDArray view2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
|
||||
|
||||
view.toStringFull();
|
||||
view2.toStringFull();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTypeCastingToString(){
|
||||
|
||||
for(DataType dt : DataType.values()) {
|
||||
if (dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||
continue;
|
||||
INDArray a1 = Nd4j.create(dt, 10);
|
||||
for(DataType dt2 : DataType.values()) {
|
||||
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
|
||||
continue;
|
||||
|
||||
INDArray a2 = a1.castTo(dt2);
|
||||
a2.toStringFull();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue