J9+ -> J8 ByteBuffer fix (#59)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-20 07:43:17 +03:00 committed by GitHub
parent 630409cd53
commit 3f38900c33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 19 additions and 13 deletions

View File

@ -21,6 +21,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import java.io.Serializable; import java.io.Serializable;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
@ -142,7 +143,7 @@ public class CompressionDescriptor implements Cloneable, Serializable {
directAlloc.putLong(numberOfElements); directAlloc.putLong(numberOfElements);
directAlloc.putLong(originalElementSize); directAlloc.putLong(originalElementSize);
directAlloc.putInt(originalDataType.ordinal()); directAlloc.putInt(originalDataType.ordinal());
directAlloc.rewind(); ((Buffer) directAlloc).rewind();
return directAlloc; return directAlloc;
} }

View File

@ -93,6 +93,7 @@ import org.nd4j.versioncheck.VersionCheck;
import java.io.*; import java.io.*;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.Channels; import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel; import java.nio.channels.WritableByteChannel;
@ -5681,7 +5682,7 @@ public class Nd4j {
public static INDArray createNpyFromByteArray(@NonNull byte[] input) { public static INDArray createNpyFromByteArray(@NonNull byte[] input) {
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length); ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length);
byteBuffer.put(input); byteBuffer.put(input);
byteBuffer.rewind(); ((Buffer) byteBuffer).rewind();
Pointer pointer = new Pointer(byteBuffer); Pointer pointer = new Pointer(byteBuffer);
return createFromNpyPointer(pointer); return createFromNpyPointer(pointer);
} }

View File

@ -31,6 +31,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.io.*; import java.io.*;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.nio.channels.Channels; import java.nio.channels.Channels;
@ -215,7 +216,7 @@ public class BinarySerde {
allocated.put(shapeBuffer); allocated.put(shapeBuffer);
allocated.put(buffer); allocated.put(buffer);
if (rewind) if (rewind)
allocated.rewind(); ((Buffer) allocated).rewind();
} }
/** /**
@ -247,7 +248,7 @@ public class BinarySerde {
//finally put the data //finally put the data
allocated.put(buffer); allocated.put(buffer);
if (rewind) if (rewind)
allocated.rewind(); ((Buffer) allocated).rewind();
} }
@ -314,7 +315,7 @@ public class BinarySerde {
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder()); ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
buffer.position(0); ((Buffer) buffer).position(0);
int rank = byteBuffer.getInt(); int rank = byteBuffer.getInt();
val result = new long[Shape.shapeInfoLength(rank)]; val result = new long[Shape.shapeInfoLength(rank)];
@ -324,7 +325,7 @@ public class BinarySerde {
// skipping two next values (dtype and rank again) // skipping two next values (dtype and rank again)
// please , that this time rank has dtype of LONG, so takes 8 bytes. // please , that this time rank has dtype of LONG, so takes 8 bytes.
byteBuffer.position(16); ((Buffer) byteBuffer).position(16);
// filling shape information // filling shape information
for (int e = 1; e < Shape.shapeInfoLength(rank); e++) { for (int e = 1; e < Shape.shapeInfoLength(rank); e++) {

View File

@ -36,6 +36,7 @@ import org.nd4j.linalg.util.ArrayUtil;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.nio.charset.Charset; import java.nio.charset.Charset;
@ -492,8 +493,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
directBuffer.put(pathBytes); directBuffer.put(pathBytes);
directBuffer.rewind(); ((Buffer) directBuffer).rewind();
directBuffer.position(0); ((Buffer) directBuffer).position(0);
Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer)); Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer));
INDArray result = createFromNpyPointer(pointer); INDArray result = createFromNpyPointer(pointer);
@ -672,8 +673,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
directBuffer.put(pathBytes); directBuffer.put(pathBytes);
directBuffer.rewind(); ((Buffer) directBuffer).rewind();
directBuffer.position(0); ((Buffer) directBuffer).position(0);
Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer)); Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer));
int n = nativeOps.getNumNpyArraysInMap(pointer); int n = nativeOps.getNumNpyArraysInMap(pointer);
HashMap<String, INDArray> map = new HashMap<>(); HashMap<String, INDArray> map = new HashMap<>();

View File

@ -32,6 +32,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.Closeable; import java.io.Closeable;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
/** /**
@ -129,7 +130,7 @@ public class AeronNDArrayPublisher implements AutoCloseable {
NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128); NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128);
for (int i = 0; i < chunks.length; i++) { for (int i = 0; i < chunks.length; i++) {
ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunks[i]); ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunks[i]);
sendBuff.rewind(); ((Buffer) sendBuff).rewind();
DirectBuffer buffer = new UnsafeBuffer(sendBuff); DirectBuffer buffer = new UnsafeBuffer(sendBuff);
sendBuffer(buffer); sendBuffer(buffer);
} }

View File

@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable; import java.io.Serializable;
import java.nio.Buffer;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.time.Instant; import java.time.Instant;
@ -229,7 +230,7 @@ public class NDArrayMessage implements Serializable {
for (int i = 0; i < chunks.length; i++) { for (int i = 0; i < chunks.length; i++) {
ByteBuffer curr = chunks[i].getData(); ByteBuffer curr = chunks[i].getData();
if (curr.capacity() > chunks[0].getChunkSize()) { if (curr.capacity() > chunks[0].getChunkSize()) {
curr.position(0).limit(chunks[0].getChunkSize()); ((Buffer) curr).position(0).limit(chunks[0].getChunkSize());
curr = curr.slice(); curr = curr.slice();
} }
all.put(curr); all.put(curr);
@ -311,7 +312,7 @@ public class NDArrayMessage implements Serializable {
//rewind the buffer before putting it in to the unsafe buffer //rewind the buffer before putting it in to the unsafe buffer
//note that we set rewind to false in the do byte buffer put methods //note that we set rewind to false in the do byte buffer put methods
byteBuffer.rewind(); ((Buffer) byteBuffer).rewind();
return new UnsafeBuffer(byteBuffer); return new UnsafeBuffer(byteBuffer);
} }