parent
630409cd53
commit
3f38900c33
|
@ -21,6 +21,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
|
|
||||||
|
@ -142,7 +143,7 @@ public class CompressionDescriptor implements Cloneable, Serializable {
|
||||||
directAlloc.putLong(numberOfElements);
|
directAlloc.putLong(numberOfElements);
|
||||||
directAlloc.putLong(originalElementSize);
|
directAlloc.putLong(originalElementSize);
|
||||||
directAlloc.putInt(originalDataType.ordinal());
|
directAlloc.putInt(originalDataType.ordinal());
|
||||||
directAlloc.rewind();
|
((Buffer) directAlloc).rewind();
|
||||||
return directAlloc;
|
return directAlloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,7 @@ import org.nd4j.versioncheck.VersionCheck;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Constructor;
|
import java.lang.reflect.Constructor;
|
||||||
import java.math.BigDecimal;
|
import java.math.BigDecimal;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.Channels;
|
import java.nio.channels.Channels;
|
||||||
import java.nio.channels.WritableByteChannel;
|
import java.nio.channels.WritableByteChannel;
|
||||||
|
@ -5681,7 +5682,7 @@ public class Nd4j {
|
||||||
public static INDArray createNpyFromByteArray(@NonNull byte[] input) {
|
public static INDArray createNpyFromByteArray(@NonNull byte[] input) {
|
||||||
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length);
|
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length);
|
||||||
byteBuffer.put(input);
|
byteBuffer.put(input);
|
||||||
byteBuffer.rewind();
|
((Buffer) byteBuffer).rewind();
|
||||||
Pointer pointer = new Pointer(byteBuffer);
|
Pointer pointer = new Pointer(byteBuffer);
|
||||||
return createFromNpyPointer(pointer);
|
return createFromNpyPointer(pointer);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.nio.channels.Channels;
|
import java.nio.channels.Channels;
|
||||||
|
@ -215,7 +216,7 @@ public class BinarySerde {
|
||||||
allocated.put(shapeBuffer);
|
allocated.put(shapeBuffer);
|
||||||
allocated.put(buffer);
|
allocated.put(buffer);
|
||||||
if (rewind)
|
if (rewind)
|
||||||
allocated.rewind();
|
((Buffer) allocated).rewind();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -247,7 +248,7 @@ public class BinarySerde {
|
||||||
//finally put the data
|
//finally put the data
|
||||||
allocated.put(buffer);
|
allocated.put(buffer);
|
||||||
if (rewind)
|
if (rewind)
|
||||||
allocated.rewind();
|
((Buffer) allocated).rewind();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -314,7 +315,7 @@ public class BinarySerde {
|
||||||
|
|
||||||
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
|
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
|
||||||
|
|
||||||
buffer.position(0);
|
((Buffer) buffer).position(0);
|
||||||
int rank = byteBuffer.getInt();
|
int rank = byteBuffer.getInt();
|
||||||
|
|
||||||
val result = new long[Shape.shapeInfoLength(rank)];
|
val result = new long[Shape.shapeInfoLength(rank)];
|
||||||
|
@ -324,7 +325,7 @@ public class BinarySerde {
|
||||||
|
|
||||||
// skipping two next values (dtype and rank again)
|
// skipping two next values (dtype and rank again)
|
||||||
// please , that this time rank has dtype of LONG, so takes 8 bytes.
|
// please , that this time rank has dtype of LONG, so takes 8 bytes.
|
||||||
byteBuffer.position(16);
|
((Buffer) byteBuffer).position(16);
|
||||||
|
|
||||||
// filling shape information
|
// filling shape information
|
||||||
for (int e = 1; e < Shape.shapeInfoLength(rank); e++) {
|
for (int e = 1; e < Shape.shapeInfoLength(rank); e++) {
|
||||||
|
|
|
@ -36,6 +36,7 @@ import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.nio.charset.Charset;
|
import java.nio.charset.Charset;
|
||||||
|
@ -492,8 +493,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
|
||||||
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
|
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
|
||||||
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
|
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
|
||||||
directBuffer.put(pathBytes);
|
directBuffer.put(pathBytes);
|
||||||
directBuffer.rewind();
|
((Buffer) directBuffer).rewind();
|
||||||
directBuffer.position(0);
|
((Buffer) directBuffer).position(0);
|
||||||
Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer));
|
Pointer pointer = nativeOps.numpyFromFile(new BytePointer(directBuffer));
|
||||||
|
|
||||||
INDArray result = createFromNpyPointer(pointer);
|
INDArray result = createFromNpyPointer(pointer);
|
||||||
|
@ -672,8 +673,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
|
||||||
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
|
byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8"));
|
||||||
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
|
ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder());
|
||||||
directBuffer.put(pathBytes);
|
directBuffer.put(pathBytes);
|
||||||
directBuffer.rewind();
|
((Buffer) directBuffer).rewind();
|
||||||
directBuffer.position(0);
|
((Buffer) directBuffer).position(0);
|
||||||
Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer));
|
Pointer pointer = nativeOps.mapFromNpzFile(new BytePointer(directBuffer));
|
||||||
int n = nativeOps.getNumNpyArraysInMap(pointer);
|
int n = nativeOps.getNumNpyArraysInMap(pointer);
|
||||||
HashMap<String, INDArray> map = new HashMap<>();
|
HashMap<String, INDArray> map = new HashMap<>();
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -129,7 +130,7 @@ public class AeronNDArrayPublisher implements AutoCloseable {
|
||||||
NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128);
|
NDArrayMessageChunk[] chunks = NDArrayMessage.chunks(message, publication.maxMessageLength() / 128);
|
||||||
for (int i = 0; i < chunks.length; i++) {
|
for (int i = 0; i < chunks.length; i++) {
|
||||||
ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunks[i]);
|
ByteBuffer sendBuff = NDArrayMessageChunk.toBuffer(chunks[i]);
|
||||||
sendBuff.rewind();
|
((Buffer) sendBuff).rewind();
|
||||||
DirectBuffer buffer = new UnsafeBuffer(sendBuff);
|
DirectBuffer buffer = new UnsafeBuffer(sendBuff);
|
||||||
sendBuffer(buffer);
|
sendBuffer(buffer);
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.nio.Buffer;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
@ -229,7 +230,7 @@ public class NDArrayMessage implements Serializable {
|
||||||
for (int i = 0; i < chunks.length; i++) {
|
for (int i = 0; i < chunks.length; i++) {
|
||||||
ByteBuffer curr = chunks[i].getData();
|
ByteBuffer curr = chunks[i].getData();
|
||||||
if (curr.capacity() > chunks[0].getChunkSize()) {
|
if (curr.capacity() > chunks[0].getChunkSize()) {
|
||||||
curr.position(0).limit(chunks[0].getChunkSize());
|
((Buffer) curr).position(0).limit(chunks[0].getChunkSize());
|
||||||
curr = curr.slice();
|
curr = curr.slice();
|
||||||
}
|
}
|
||||||
all.put(curr);
|
all.put(curr);
|
||||||
|
@ -311,7 +312,7 @@ public class NDArrayMessage implements Serializable {
|
||||||
|
|
||||||
//rewind the buffer before putting it in to the unsafe buffer
|
//rewind the buffer before putting it in to the unsafe buffer
|
||||||
//note that we set rewind to false in the do byte buffer put methods
|
//note that we set rewind to false in the do byte buffer put methods
|
||||||
byteBuffer.rewind();
|
((Buffer) byteBuffer).rewind();
|
||||||
|
|
||||||
return new UnsafeBuffer(byteBuffer);
|
return new UnsafeBuffer(byteBuffer);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue