Fix wrong indexer for some DataBuffer constructors for UINT32 datatype (#458)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-05-13 19:47:51 +10:00 committed by GitHub
parent c396fcb960
commit a4d74ec4d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 3 deletions

View File

@ -1630,7 +1630,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
setIndexer(ShortIndexer.create((ShortPointer) pointer));
} else if (t == DataType.UINT32) {
pointer = new PagedPointer(cptr, length).asIntPointer();
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (t == DataType.INT) {
pointer = new PagedPointer(cptr, length).asIntPointer();
setIndexer(IntIndexer.create((IntPointer) pointer));
@ -1699,6 +1699,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case UINT32:
pointer = nPtr.asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
pointer = nPtr.asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -1750,6 +1753,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case UINT32:
pointer = nPtr.asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
pointer = nPtr.asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);

View File

@ -411,7 +411,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
setIndexer(ShortIndexer.create((ShortPointer) pointer));
} else if (t == DataType.UINT32) {
pointer = new PagedPointer(cptr, length).asIntPointer();
setIndexer(IntIndexer.create((IntPointer) pointer));
setIndexer(UIntIndexer.create((IntPointer) pointer));
} else if (t == DataType.INT) {
pointer = new PagedPointer(cptr, length).asIntPointer();
setIndexer(IntIndexer.create((IntPointer) pointer));
@ -514,7 +514,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
attached = true;
parentWorkspace = workspace;
// FIXME: need unsigned indexer here
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
setIndexer(UIntIndexer.create((IntPointer) pointer));
@ -882,6 +881,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case UINT32:
pointer = nPtr.asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
pointer = nPtr.asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
@ -932,6 +934,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case UINT32:
pointer = nPtr.asIntPointer();
indexer = UIntIndexer.create((IntPointer) pointer);
break;
case INT:
pointer = nPtr.asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);

View File

@ -8411,6 +8411,56 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
arr.toStringFull();
arr.toString();
for(DataType dt2 : DataType.values()) {
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
continue;
INDArray a2 = arr.castTo(dt2);
a2.toStringFull();
}
}
}
@Test
public void testCreateBufferFromByteBufferViews(){
for(DataType dt : DataType.values()){
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
continue;
// System.out.println(dt);
int lengthBytes = 256;
int lengthElements = lengthBytes / dt.width();
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
INDArray arr = Nd4j.create(db, new long[]{lengthElements/2, 2});
arr.toStringFull();
INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
INDArray view2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
view.toStringFull();
view2.toStringFull();
}
}
@Test
public void testTypeCastingToString(){
for(DataType dt : DataType.values()) {
if (dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
continue;
INDArray a1 = Nd4j.create(dt, 10);
for(DataType dt2 : DataType.values()) {
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
continue;
INDArray a2 = a1.castTo(dt2);
a2.toStringFull();
}
}
}