DataBuffer.write() tweak (#221)
* special workaround methods for DataBuffer.write Signed-off-by: raver119 <raver119@gmail.com> * one test removed Signed-off-by: raver119 <raver119@gmail.com> * more of unsynced Signed-off-by: raver119 <raver119@gmail.com> * missing asLong for BaseCudaDataBuffer Signed-off-by: raver119 <raver119@gmail.com>master
parent
a0da5a9e47
commit
1dfac9a736
|
@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||||
float[] ret = new float[(int) length];
|
float[] ret = new float[(int) length];
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
ret[i] = getFloat(i);
|
ret[i] = getFloatUnsynced(i);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||||
double[] ret = new double[(int) length];
|
double[] ret = new double[(int) length];
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
ret[i] = getDouble(i);
|
ret[i] = getDoubleUnsynced(i);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||||
int[] ret = new int[(int) length];
|
int[] ret = new int[(int) length];
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
ret[i] = getInt(i);
|
ret[i] = getIntUnsynced(i);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||||
long[] ret = new long[(int) length];
|
long[] ret = new long[(int) length];
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
ret[i] = getLong(i);
|
ret[i] = getLongUnsynced(i);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected abstract double getDoubleUnsynced(long index);
|
||||||
|
protected abstract float getFloatUnsynced(long index);
|
||||||
|
protected abstract long getLongUnsynced(long index);
|
||||||
|
protected abstract int getIntUnsynced(long index);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(DataOutputStream out) throws IOException {
|
public void write(DataOutputStream out) throws IOException {
|
||||||
out.writeUTF(allocationMode.name());
|
out.writeUTF(allocationMode.name());
|
||||||
|
@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeDouble(getDouble(i));
|
out.writeDouble(getDoubleUnsynced(i));
|
||||||
break;
|
break;
|
||||||
case UINT64:
|
case UINT64:
|
||||||
case LONG:
|
case LONG:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeLong(getLong(i));
|
out.writeLong(getLongUnsynced(i));
|
||||||
break;
|
break;
|
||||||
case UINT32:
|
case UINT32:
|
||||||
case INT:
|
case INT:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeInt(getInt(i));
|
out.writeInt(getIntUnsynced(i));
|
||||||
break;
|
break;
|
||||||
case UINT16:
|
case UINT16:
|
||||||
case SHORT:
|
case SHORT:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeShort((short) getInt(i));
|
out.writeShort((short) getIntUnsynced(i));
|
||||||
break;
|
break;
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
case BYTE:
|
case BYTE:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeByte((byte) getInt(i));
|
out.writeByte((byte) getIntUnsynced(i));
|
||||||
break;
|
break;
|
||||||
case BOOL:
|
case BOOL:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1);
|
out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1);
|
||||||
break;
|
break;
|
||||||
case BFLOAT16:
|
case BFLOAT16:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i)));
|
out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i)));
|
||||||
break;
|
break;
|
||||||
case HALF:
|
case HALF:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeShort((short) HalfIndexer.fromFloat(getFloat(i)));
|
out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i)));
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
for (long i = 0; i < length(); i++)
|
for (long i = 0; i < length(); i++)
|
||||||
out.writeFloat(getFloat(i));
|
out.writeFloat(getFloatUnsynced(i));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
||||||
public DataBuffer reallocate(long length) {
|
public DataBuffer reallocate(long length) {
|
||||||
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected double getDoubleUnsynced(long index) {
|
||||||
|
return super.getDouble(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected float getFloatUnsynced(long index) {
|
||||||
|
return super.getFloat(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected long getLongUnsynced(long index) {
|
||||||
|
return super.getLong(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int getIntUnsynced(long index) {
|
||||||
|
return super.getInt(index);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1287,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
@Override
|
@Override
|
||||||
public void destroy() {}
|
public void destroy() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected double getDoubleUnsynced(long index) {
|
||||||
|
return super.getDouble(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected float getFloatUnsynced(long index) {
|
||||||
|
return super.getFloat(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected long getLongUnsynced(long index) {
|
||||||
|
return super.getLong(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int getIntUnsynced(long index) {
|
||||||
|
return super.getInt(index);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(DataOutputStream out) throws IOException {
|
public void write(DataOutputStream out) throws IOException {
|
||||||
lazyAllocateHostPointer();
|
lazyAllocateHostPointer();
|
||||||
|
@ -1510,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
return super.asInt();
|
return super.asInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long[] asLong() {
|
||||||
|
lazyAllocateHostPointer();
|
||||||
|
allocator.synchronizeHostData(this);
|
||||||
|
return super.asLong();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteBuffer asNio() {
|
public ByteBuffer asNio() {
|
||||||
lazyAllocateHostPointer();
|
lazyAllocateHostPointer();
|
||||||
|
|
|
@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
||||||
Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
|
Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected double getDoubleUnsynced(long index) {
|
||||||
|
return super.getDouble(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected float getFloatUnsynced(long index) {
|
||||||
|
return super.getFloat(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected long getLongUnsynced(long index) {
|
||||||
|
return super.getLong(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int getIntUnsynced(long index) {
|
||||||
|
return super.getInt(index);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void pointerIndexerByCurrentType(DataType currentType) {
|
public void pointerIndexerByCurrentType(DataType currentType) {
|
||||||
|
|
||||||
|
|
|
@ -8262,31 +8262,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDealloc_1() throws Exception {
|
|
||||||
|
|
||||||
for (int e = 0; e < 5000; e++){
|
|
||||||
try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) {
|
|
||||||
val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000);
|
|
||||||
//val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10);
|
|
||||||
//val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10);
|
|
||||||
//val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10);
|
|
||||||
} finally {
|
|
||||||
//System.gc();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Thread.sleep(1000);
|
|
||||||
System.gc();
|
|
||||||
|
|
||||||
Thread.sleep(1000);
|
|
||||||
System.gc();
|
|
||||||
System.gc();
|
|
||||||
System.gc();
|
|
||||||
|
|
||||||
//Nd4j.getMemoryManager().printRemainingStacks();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue