DataBuffer.write() tweak (#221)

* special workaround methods for DataBuffer.write

Signed-off-by: raver119 <raver119@gmail.com>

* one test removed

Signed-off-by: raver119 <raver119@gmail.com>

* more of unsynced

Signed-off-by: raver119 <raver119@gmail.com>

* missing asLong for BaseCudaDataBuffer

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-07 18:16:11 +03:00 committed by GitHub
parent a0da5a9e47
commit 1dfac9a736
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 85 additions and 38 deletions

View File

@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length); throw new IllegalArgumentException("Unable to create array of length " + length);
float[] ret = new float[(int) length]; float[] ret = new float[(int) length];
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
ret[i] = getFloat(i); ret[i] = getFloatUnsynced(i);
return ret; return ret;
} }
@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length); throw new IllegalArgumentException("Unable to create array of length " + length);
double[] ret = new double[(int) length]; double[] ret = new double[(int) length];
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
ret[i] = getDouble(i); ret[i] = getDoubleUnsynced(i);
return ret; return ret;
} }
@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length); throw new IllegalArgumentException("Unable to create array of length " + length);
int[] ret = new int[(int) length]; int[] ret = new int[(int) length];
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
ret[i] = getInt(i); ret[i] = getIntUnsynced(i);
return ret; return ret;
} }
@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
throw new IllegalArgumentException("Unable to create array of length " + length); throw new IllegalArgumentException("Unable to create array of length " + length);
long[] ret = new long[(int) length]; long[] ret = new long[(int) length];
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
ret[i] = getLong(i); ret[i] = getLongUnsynced(i);
return ret; return ret;
} }
@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
} }
protected abstract double getDoubleUnsynced(long index);
protected abstract float getFloatUnsynced(long index);
protected abstract long getLongUnsynced(long index);
protected abstract int getIntUnsynced(long index);
@Override @Override
public void write(DataOutputStream out) throws IOException { public void write(DataOutputStream out) throws IOException {
out.writeUTF(allocationMode.name()); out.writeUTF(allocationMode.name());
@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer {
switch (dataType()) { switch (dataType()) {
case DOUBLE: case DOUBLE:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeDouble(getDouble(i)); out.writeDouble(getDoubleUnsynced(i));
break; break;
case UINT64: case UINT64:
case LONG: case LONG:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeLong(getLong(i)); out.writeLong(getLongUnsynced(i));
break; break;
case UINT32: case UINT32:
case INT: case INT:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeInt(getInt(i)); out.writeInt(getIntUnsynced(i));
break; break;
case UINT16: case UINT16:
case SHORT: case SHORT:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeShort((short) getInt(i)); out.writeShort((short) getIntUnsynced(i));
break; break;
case UBYTE: case UBYTE:
case BYTE: case BYTE:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeByte((byte) getInt(i)); out.writeByte((byte) getIntUnsynced(i));
break; break;
case BOOL: case BOOL:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1); out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1);
break; break;
case BFLOAT16: case BFLOAT16:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i))); out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i)));
break; break;
case HALF: case HALF:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeShort((short) HalfIndexer.fromFloat(getFloat(i))); out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i)));
break; break;
case FLOAT: case FLOAT:
for (long i = 0; i < length(); i++) for (long i = 0; i < length(); i++)
out.writeFloat(getFloat(i)); out.writeFloat(getFloatUnsynced(i));
break; break;
} }
} }

View File

@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer {
public DataBuffer reallocate(long length) { public DataBuffer reallocate(long length) {
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
} }
@Override
protected double getDoubleUnsynced(long index) {
return super.getDouble(index);
}
@Override
protected float getFloatUnsynced(long index) {
return super.getFloat(index);
}
@Override
protected long getLongUnsynced(long index) {
return super.getLong(index);
}
@Override
protected int getIntUnsynced(long index) {
return super.getInt(index);
}
} }

View File

@ -1287,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override @Override
public void destroy() {} public void destroy() {}
@Override
protected double getDoubleUnsynced(long index) {
return super.getDouble(index);
}
@Override
protected float getFloatUnsynced(long index) {
return super.getFloat(index);
}
@Override
protected long getLongUnsynced(long index) {
return super.getLong(index);
}
@Override
protected int getIntUnsynced(long index) {
return super.getInt(index);
}
@Override @Override
public void write(DataOutputStream out) throws IOException { public void write(DataOutputStream out) throws IOException {
lazyAllocateHostPointer(); lazyAllocateHostPointer();
@ -1510,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
return super.asInt(); return super.asInt();
} }
@Override
public long[] asLong() {
lazyAllocateHostPointer();
allocator.synchronizeHostData(this);
return super.asLong();
}
@Override @Override
public ByteBuffer asNio() { public ByteBuffer asNio() {
lazyAllocateHostPointer(); lazyAllocateHostPointer();

View File

@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype)); Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
} }
@Override
protected double getDoubleUnsynced(long index) {
return super.getDouble(index);
}
@Override
protected float getFloatUnsynced(long index) {
return super.getFloat(index);
}
@Override
protected long getLongUnsynced(long index) {
return super.getLong(index);
}
@Override
protected int getIntUnsynced(long index) {
return super.getInt(index);
}
@Override @Override
public void pointerIndexerByCurrentType(DataType currentType) { public void pointerIndexerByCurrentType(DataType currentType) {

View File

@ -8262,31 +8262,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertArrayEquals(new long[]{10, 0}, out2.shape()); assertArrayEquals(new long[]{10, 0}, out2.shape());
} }
@Test
public void testDealloc_1() throws Exception {
for (int e = 0; e < 5000; e++){
try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) {
val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000);
//val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10);
//val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10);
//val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10);
} finally {
//System.gc();
}
}
Thread.sleep(1000);
System.gc();
Thread.sleep(1000);
System.gc();
System.gc();
System.gc();
//Nd4j.getMemoryManager().printRemainingStacks();
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';