DataBuffer.write() tweak (#221)
* special workaround methods for DataBuffer.write Signed-off-by: raver119 <raver119@gmail.com> * one test removed Signed-off-by: raver119 <raver119@gmail.com> * more of unsynced Signed-off-by: raver119 <raver119@gmail.com> * missing asLong for BaseCudaDataBuffer Signed-off-by: raver119 <raver119@gmail.com>master
parent
a0da5a9e47
commit
1dfac9a736
|
@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||
float[] ret = new float[(int) length];
|
||||
for (int i = 0; i < length; i++)
|
||||
ret[i] = getFloat(i);
|
||||
ret[i] = getFloatUnsynced(i);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||
double[] ret = new double[(int) length];
|
||||
for (int i = 0; i < length; i++)
|
||||
ret[i] = getDouble(i);
|
||||
ret[i] = getDoubleUnsynced(i);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||
int[] ret = new int[(int) length];
|
||||
for (int i = 0; i < length; i++)
|
||||
ret[i] = getInt(i);
|
||||
ret[i] = getIntUnsynced(i);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
throw new IllegalArgumentException("Unable to create array of length " + length);
|
||||
long[] ret = new long[(int) length];
|
||||
for (int i = 0; i < length; i++)
|
||||
ret[i] = getLong(i);
|
||||
ret[i] = getLongUnsynced(i);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
}
|
||||
|
||||
protected abstract double getDoubleUnsynced(long index);
|
||||
protected abstract float getFloatUnsynced(long index);
|
||||
protected abstract long getLongUnsynced(long index);
|
||||
protected abstract int getIntUnsynced(long index);
|
||||
|
||||
@Override
|
||||
public void write(DataOutputStream out) throws IOException {
|
||||
out.writeUTF(allocationMode.name());
|
||||
|
@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
switch (dataType()) {
|
||||
case DOUBLE:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeDouble(getDouble(i));
|
||||
out.writeDouble(getDoubleUnsynced(i));
|
||||
break;
|
||||
case UINT64:
|
||||
case LONG:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeLong(getLong(i));
|
||||
out.writeLong(getLongUnsynced(i));
|
||||
break;
|
||||
case UINT32:
|
||||
case INT:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeInt(getInt(i));
|
||||
out.writeInt(getIntUnsynced(i));
|
||||
break;
|
||||
case UINT16:
|
||||
case SHORT:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeShort((short) getInt(i));
|
||||
out.writeShort((short) getIntUnsynced(i));
|
||||
break;
|
||||
case UBYTE:
|
||||
case BYTE:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeByte((byte) getInt(i));
|
||||
out.writeByte((byte) getIntUnsynced(i));
|
||||
break;
|
||||
case BOOL:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1);
|
||||
out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1);
|
||||
break;
|
||||
case BFLOAT16:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i)));
|
||||
out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i)));
|
||||
break;
|
||||
case HALF:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeShort((short) HalfIndexer.fromFloat(getFloat(i)));
|
||||
out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i)));
|
||||
break;
|
||||
case FLOAT:
|
||||
for (long i = 0; i < length(); i++)
|
||||
out.writeFloat(getFloat(i));
|
||||
out.writeFloat(getFloatUnsynced(i));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer {
|
|||
public DataBuffer reallocate(long length) {
|
||||
throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double getDoubleUnsynced(long index) {
|
||||
return super.getDouble(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float getFloatUnsynced(long index) {
|
||||
return super.getFloat(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected long getLongUnsynced(long index) {
|
||||
return super.getLong(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getIntUnsynced(long index) {
|
||||
return super.getInt(index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1287,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
@Override
|
||||
public void destroy() {}
|
||||
|
||||
@Override
|
||||
protected double getDoubleUnsynced(long index) {
|
||||
return super.getDouble(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float getFloatUnsynced(long index) {
|
||||
return super.getFloat(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected long getLongUnsynced(long index) {
|
||||
return super.getLong(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getIntUnsynced(long index) {
|
||||
return super.getInt(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void write(DataOutputStream out) throws IOException {
|
||||
lazyAllocateHostPointer();
|
||||
|
@ -1510,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
return super.asInt();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] asLong() {
|
||||
lazyAllocateHostPointer();
|
||||
allocator.synchronizeHostData(this);
|
||||
return super.asLong();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteBuffer asNio() {
|
||||
lazyAllocateHostPointer();
|
||||
|
|
|
@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double getDoubleUnsynced(long index) {
|
||||
return super.getDouble(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float getFloatUnsynced(long index) {
|
||||
return super.getFloat(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected long getLongUnsynced(long index) {
|
||||
return super.getLong(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getIntUnsynced(long index) {
|
||||
return super.getInt(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void pointerIndexerByCurrentType(DataType currentType) {
|
||||
|
||||
|
|
|
@ -8262,31 +8262,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDealloc_1() throws Exception {
|
||||
|
||||
for (int e = 0; e < 5000; e++){
|
||||
try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) {
|
||||
val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000);
|
||||
//val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10);
|
||||
//val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10);
|
||||
//val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10);
|
||||
} finally {
|
||||
//System.gc();
|
||||
}
|
||||
}
|
||||
|
||||
Thread.sleep(1000);
|
||||
System.gc();
|
||||
|
||||
Thread.sleep(1000);
|
||||
System.gc();
|
||||
System.gc();
|
||||
System.gc();
|
||||
|
||||
//Nd4j.getMemoryManager().printRemainingStacks();
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue