From 1dfac9a7362bdb787ab8a90781f5ad13db65d9b1 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 7 Feb 2020 18:16:11 +0300 Subject: [PATCH] DataBuffer.write() tweak (#221) * special workaround methods for DataBuffer.write Signed-off-by: raver119 * one test removed Signed-off-by: raver119 * more of unsynced Signed-off-by: raver119 * missing asLong for BaseCudaDataBuffer Signed-off-by: raver119 --- .../linalg/api/buffer/BaseDataBuffer.java | 31 +++++++++++-------- .../compression/CompressedDataBuffer.java | 20 ++++++++++++ .../jcublas/buffer/BaseCudaDataBuffer.java | 27 ++++++++++++++++ .../nativecpu/buffer/BaseCpuDataBuffer.java | 20 ++++++++++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 25 --------------- 5 files changed, 85 insertions(+), 38 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 78b12e7fc..12e27e1c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); float[] ret = new float[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getFloat(i); + ret[i] = getFloatUnsynced(i); return ret; } @@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); double[] ret = new double[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getDouble(i); + ret[i] = getDoubleUnsynced(i); return ret; } @@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); int[] ret = new int[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getInt(i); + ret[i] = getIntUnsynced(i); return ret; } @@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); long[] ret = new long[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getLong(i); + ret[i] = getLongUnsynced(i); return ret; } @@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer { } + protected abstract double getDoubleUnsynced(long index); + protected abstract float getFloatUnsynced(long index); + protected abstract long getLongUnsynced(long index); + protected abstract int getIntUnsynced(long index); + @Override public void write(DataOutputStream out) throws IOException { out.writeUTF(allocationMode.name()); @@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: for (long i = 0; i < length(); i++) - out.writeDouble(getDouble(i)); + out.writeDouble(getDoubleUnsynced(i)); break; case UINT64: case LONG: for (long i = 0; i < length(); i++) - out.writeLong(getLong(i)); + out.writeLong(getLongUnsynced(i)); break; case UINT32: case INT: for (long i = 0; i < length(); i++) - out.writeInt(getInt(i)); + out.writeInt(getIntUnsynced(i)); break; case UINT16: case SHORT: for (long i = 0; i < length(); i++) - out.writeShort((short) getInt(i)); + out.writeShort((short) getIntUnsynced(i)); break; case UBYTE: case BYTE: for (long i = 0; i < length(); i++) - out.writeByte((byte) getInt(i)); + out.writeByte((byte) getIntUnsynced(i)); break; case BOOL: for (long i = 0; i < length(); i++) - out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1); + out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1); break; case BFLOAT16: for (long i = 0; i < length(); i++) - out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i))); + out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i))); break; case HALF: for (long i = 0; i < length(); i++) - out.writeShort((short) HalfIndexer.fromFloat(getFloat(i))); + out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i))); break; case FLOAT: for (long i = 0; i < length(); i++) - out.writeFloat(getFloat(i)); + out.writeFloat(getFloatUnsynced(i)); break; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index 0c822ce0a..f1c9ed6d9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer { public DataBuffer reallocate(long length) { throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); } + + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index cdec4e1be..2f1cab334 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1287,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public void destroy() {} + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } + @Override public void write(DataOutputStream out) throws IOException { lazyAllocateHostPointer(); @@ -1510,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda return super.asInt(); } + @Override + public long[] asLong() { + lazyAllocateHostPointer(); + allocator.synchronizeHostData(this); + return super.asLong(); + } + @Override public ByteBuffer asNio() { lazyAllocateHostPointer(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index a51666f78..a5ddc7aef 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype)); } + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } + @Override public void pointerIndexerByCurrentType(DataType currentType) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index ad5bacc4e..d96c0ed31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8262,31 +8262,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertArrayEquals(new long[]{10, 0}, out2.shape()); } - @Test - public void testDealloc_1() throws Exception { - - for (int e = 0; e < 5000; e++){ - try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) { - val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000); - //val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10); - //val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10); - //val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10); - } finally { - //System.gc(); - } - } - - Thread.sleep(1000); - System.gc(); - - Thread.sleep(1000); - System.gc(); - System.gc(); - System.gc(); - - //Nd4j.getMemoryManager().printRemainingStacks(); - } - @Override public char ordering() { return 'c';