From a4d74ec4d0168dc4ed6dd3c68cddac2f5e7dc7f2 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 13 May 2020 19:47:51 +1000 Subject: [PATCH] Fix wrong indexer for some DataBuffer constructors for UINT32 datatype (#458) Signed-off-by: Alex Black --- .../jcublas/buffer/BaseCudaDataBuffer.java | 8 ++- .../nativecpu/buffer/BaseCpuDataBuffer.java | 9 +++- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 50 +++++++++++++++++++ 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 14e64df61..1da7e1c36 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1630,7 +1630,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda setIndexer(ShortIndexer.create((ShortPointer) pointer)); } else if (t == DataType.UINT32) { pointer = new PagedPointer(cptr, length).asIntPointer(); - setIndexer(IntIndexer.create((IntPointer) pointer)); + setIndexer(UIntIndexer.create((IntPointer) pointer)); } else if (t == DataType.INT) { pointer = new PagedPointer(cptr, length).asIntPointer(); setIndexer(IntIndexer.create((IntPointer) pointer)); @@ -1699,6 +1699,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda indexer = ShortIndexer.create((ShortPointer) pointer); break; case UINT32: + pointer = nPtr.asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: pointer = nPtr.asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); @@ -1750,6 +1753,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda indexer = ShortIndexer.create((ShortPointer) pointer); break; case UINT32: + pointer = nPtr.asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: pointer = nPtr.asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index 7a2a8467a..6666068da 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -411,7 +411,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo setIndexer(ShortIndexer.create((ShortPointer) pointer)); } else if (t == DataType.UINT32) { pointer = new PagedPointer(cptr, length).asIntPointer(); - setIndexer(IntIndexer.create((IntPointer) pointer)); + setIndexer(UIntIndexer.create((IntPointer) pointer)); } else if (t == DataType.INT) { pointer = new PagedPointer(cptr, length).asIntPointer(); setIndexer(IntIndexer.create((IntPointer) pointer)); @@ -514,7 +514,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo attached = true; parentWorkspace = workspace; - // FIXME: need unsigned indexer here pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length()); setIndexer(UIntIndexer.create((IntPointer) pointer)); @@ -882,6 +881,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo indexer = ShortIndexer.create((ShortPointer) pointer); break; case UINT32: + pointer = nPtr.asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: pointer = nPtr.asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); @@ -932,6 +934,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo indexer = ShortIndexer.create((ShortPointer) pointer); break; case UINT32: + pointer = nPtr.asIntPointer(); + indexer = UIntIndexer.create((IntPointer) pointer); + break; case INT: pointer = nPtr.asIntPointer(); indexer = IntIndexer.create((IntPointer) pointer); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 46f47017e..a70ede362 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8411,6 +8411,56 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr = Nd4j.create(db, new long[]{lengthElements}); arr.toStringFull(); + arr.toString(); + + for(DataType dt2 : DataType.values()) { + if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN) + continue; + INDArray a2 = arr.castTo(dt2); + a2.toStringFull(); + } + } + } + + @Test + public void testCreateBufferFromByteBufferViews(){ + + for(DataType dt : DataType.values()){ + if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) + continue; +// System.out.println(dt); + + int lengthBytes = 256; + int lengthElements = lengthBytes / dt.width(); + ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes); + + DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0); + INDArray arr = Nd4j.create(db, new long[]{lengthElements/2, 2}); + + arr.toStringFull(); + + INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0)); + INDArray view2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all()); + + view.toStringFull(); + view2.toStringFull(); + } + } + + @Test + public void testTypeCastingToString(){ + + for(DataType dt : DataType.values()) { + if (dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN) + continue; + INDArray a1 = Nd4j.create(dt, 10); + for(DataType dt2 : DataType.values()) { + if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN) + continue; + + INDArray a2 = a1.castTo(dt2); + a2.toStringFull(); + } } }