Fix wrong indexer for some DataBuffer constructors for UINT32 datatype (#458)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
c396fcb960
commit
a4d74ec4d0
|
@ -1630,7 +1630,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
||||||
} else if (t == DataType.UINT32) {
|
} else if (t == DataType.UINT32) {
|
||||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
} else if (t == DataType.INT) {
|
} else if (t == DataType.INT) {
|
||||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||||
|
@ -1699,6 +1699,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
pointer = nPtr.asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
pointer = nPtr.asIntPointer();
|
pointer = nPtr.asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
@ -1750,6 +1753,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
pointer = nPtr.asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
pointer = nPtr.asIntPointer();
|
pointer = nPtr.asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
|
|
@ -411,7 +411,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
setIndexer(ShortIndexer.create((ShortPointer) pointer));
|
||||||
} else if (t == DataType.UINT32) {
|
} else if (t == DataType.UINT32) {
|
||||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
} else if (t == DataType.INT) {
|
} else if (t == DataType.INT) {
|
||||||
pointer = new PagedPointer(cptr, length).asIntPointer();
|
pointer = new PagedPointer(cptr, length).asIntPointer();
|
||||||
setIndexer(IntIndexer.create((IntPointer) pointer));
|
setIndexer(IntIndexer.create((IntPointer) pointer));
|
||||||
|
@ -514,7 +514,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
attached = true;
|
attached = true;
|
||||||
parentWorkspace = workspace;
|
parentWorkspace = workspace;
|
||||||
|
|
||||||
// FIXME: need unsigned indexer here
|
|
||||||
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
|
||||||
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
setIndexer(UIntIndexer.create((IntPointer) pointer));
|
||||||
|
|
||||||
|
@ -882,6 +881,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
pointer = nPtr.asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
pointer = nPtr.asIntPointer();
|
pointer = nPtr.asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
@ -932,6 +934,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
indexer = ShortIndexer.create((ShortPointer) pointer);
|
indexer = ShortIndexer.create((ShortPointer) pointer);
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
|
pointer = nPtr.asIntPointer();
|
||||||
|
indexer = UIntIndexer.create((IntPointer) pointer);
|
||||||
|
break;
|
||||||
case INT:
|
case INT:
|
||||||
pointer = nPtr.asIntPointer();
|
pointer = nPtr.asIntPointer();
|
||||||
indexer = IntIndexer.create((IntPointer) pointer);
|
indexer = IntIndexer.create((IntPointer) pointer);
|
||||||
|
|
|
@ -8411,6 +8411,56 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
|
INDArray arr = Nd4j.create(db, new long[]{lengthElements});
|
||||||
|
|
||||||
arr.toStringFull();
|
arr.toStringFull();
|
||||||
|
arr.toString();
|
||||||
|
|
||||||
|
for(DataType dt2 : DataType.values()) {
|
||||||
|
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
|
||||||
|
continue;
|
||||||
|
INDArray a2 = arr.castTo(dt2);
|
||||||
|
a2.toStringFull();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateBufferFromByteBufferViews(){
|
||||||
|
|
||||||
|
for(DataType dt : DataType.values()){
|
||||||
|
if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||||
|
continue;
|
||||||
|
// System.out.println(dt);
|
||||||
|
|
||||||
|
int lengthBytes = 256;
|
||||||
|
int lengthElements = lengthBytes / dt.width();
|
||||||
|
ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
|
||||||
|
|
||||||
|
DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
|
||||||
|
INDArray arr = Nd4j.create(db, new long[]{lengthElements/2, 2});
|
||||||
|
|
||||||
|
arr.toStringFull();
|
||||||
|
|
||||||
|
INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
|
||||||
|
INDArray view2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
|
||||||
|
|
||||||
|
view.toStringFull();
|
||||||
|
view2.toStringFull();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTypeCastingToString(){
|
||||||
|
|
||||||
|
for(DataType dt : DataType.values()) {
|
||||||
|
if (dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
|
||||||
|
continue;
|
||||||
|
INDArray a1 = Nd4j.create(dt, 10);
|
||||||
|
for(DataType dt2 : DataType.values()) {
|
||||||
|
if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
INDArray a2 = a1.castTo(dt2);
|
||||||
|
a2.toStringFull();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue