2021-02-01 14:31:20 +09:00
|
|
|
/*
|
|
|
|
* ******************************************************************************
|
|
|
|
* *
|
|
|
|
* *
|
|
|
|
* * This program and the accompanying materials are made available under the
|
|
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
* *
|
2021-02-01 17:47:29 +09:00
|
|
|
* * See the NOTICE file distributed with this work for additional
|
|
|
|
* * information regarding copyright ownership.
|
2021-02-01 14:31:20 +09:00
|
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* * License for the specific language governing permissions and limitations
|
|
|
|
* * under the License.
|
|
|
|
* *
|
|
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
|
|
* *****************************************************************************
|
|
|
|
*/
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
package org.nd4j.serde.binary;
|
|
|
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import lombok.val;
|
|
|
|
import org.bytedeco.javacpp.BytePointer;
|
|
|
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
|
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
|
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
import org.nd4j.linalg.api.shape.Shape;
|
|
|
|
import org.nd4j.linalg.compression.CompressedDataBuffer;
|
|
|
|
import org.nd4j.linalg.compression.CompressionDescriptor;
|
2019-10-31 11:23:09 +02:00
|
|
|
import org.nd4j.linalg.exception.ND4JArraySizeException;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.factory.Nd4j;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.common.primitives.Pair;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
import java.io.*;
|
2019-11-20 07:43:17 +03:00
|
|
|
import java.nio.Buffer;
|
2019-06-06 15:21:15 +03:00
|
|
|
import java.nio.ByteBuffer;
|
|
|
|
import java.nio.ByteOrder;
|
|
|
|
import java.nio.channels.Channels;
|
|
|
|
import java.nio.channels.FileChannel;
|
|
|
|
import java.nio.channels.WritableByteChannel;
|
|
|
|
|
|
|
|
@Slf4j
|
|
|
|
public class BinarySerde {
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Create an ndarray
|
|
|
|
* from the unsafe buffer
|
|
|
|
* @param buffer the buffer to create the array from
|
|
|
|
* @return the ndarray derived from this buffer
|
|
|
|
*/
|
|
|
|
public static INDArray toArray(ByteBuffer buffer, int offset) {
|
|
|
|
return toArrayAndByteBuffer(buffer, offset).getLeft();
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Create an ndarray
|
|
|
|
* from the unsafe buffer
|
|
|
|
* @param buffer the buffer to create the array from
|
|
|
|
* @return the ndarray derived from this buffer
|
|
|
|
*/
|
|
|
|
public static INDArray toArray(ByteBuffer buffer) {
|
|
|
|
return toArray(buffer, 0);
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Create an ndarray and existing bytebuffer
|
2019-10-08 19:48:22 +09:00
|
|
|
* @param buffer the buffer to create the arrays from
|
|
|
|
* @param offset position in buffer to create the arrays from.
|
|
|
|
* @return the created INDArray and Bytebuffer pair.
|
2019-06-06 15:21:15 +03:00
|
|
|
*/
|
2019-10-08 19:48:22 +09:00
|
|
|
protected static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
|
2019-06-06 15:21:15 +03:00
|
|
|
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
|
|
|
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
|
|
|
//bump the byte buffer to the proper position
|
2022-09-20 15:40:53 +02:00
|
|
|
byteBuffer.position(offset);
|
2019-06-06 15:21:15 +03:00
|
|
|
int rank = byteBuffer.getInt();
|
|
|
|
if (rank < 0)
|
|
|
|
throw new IllegalStateException("Found negative integer. Corrupt serialization?");
|
|
|
|
//get the shape buffer length to create the shape information buffer
|
|
|
|
int shapeBufferLength = Shape.shapeInfoLength(rank);
|
|
|
|
//create the ndarray shape information
|
|
|
|
DataBuffer shapeBuff = Nd4j.createBufferDetached(new int[shapeBufferLength]);
|
|
|
|
|
|
|
|
//compute the databuffer opType from the index
|
|
|
|
DataType type = DataType.values()[byteBuffer.getInt()];
|
|
|
|
for (int i = 0; i < shapeBufferLength; i++) {
|
|
|
|
shapeBuff.put(i, byteBuffer.getLong());
|
|
|
|
}
|
|
|
|
|
|
|
|
//after the rank,data opType, shape buffer (of length shape buffer length) * sizeof(int)
|
|
|
|
if (type != DataType.COMPRESSED) {
|
|
|
|
ByteBuffer slice = byteBuffer.slice();
|
|
|
|
//wrap the data buffer for the last bit
|
2019-10-31 11:23:09 +02:00
|
|
|
if (Shape.length(shapeBuff) > Integer.MAX_VALUE)
|
|
|
|
throw new ND4JArraySizeException();
|
2019-06-06 15:21:15 +03:00
|
|
|
DataBuffer buff = Nd4j.createBuffer(slice, type, (int) Shape.length(shapeBuff));
|
|
|
|
//advance past the data
|
|
|
|
int position = byteBuffer.position() + (buff.getElementSize() * (int) buff.length());
|
2022-09-20 15:40:53 +02:00
|
|
|
byteBuffer.position(position);
|
2019-06-06 15:21:15 +03:00
|
|
|
//create the final array
|
|
|
|
//TODO: see how to avoid dup here
|
|
|
|
INDArray arr = Nd4j.createArrayFromShapeBuffer(buff.dup(), shapeBuff.dup());
|
|
|
|
return Pair.of(arr, byteBuffer);
|
|
|
|
} else {
|
|
|
|
CompressionDescriptor compressionDescriptor = CompressionDescriptor.fromByteBuffer(byteBuffer);
|
|
|
|
ByteBuffer slice = byteBuffer.slice();
|
|
|
|
//ensure that we only deal with the slice of the buffer that is actually the data
|
|
|
|
BytePointer byteBufferPointer = new BytePointer(slice);
|
|
|
|
//create a compressed array based on the rest of the data left in the buffer
|
|
|
|
CompressedDataBuffer compressedDataBuffer =
|
|
|
|
new CompressedDataBuffer(byteBufferPointer, compressionDescriptor);
|
|
|
|
//TODO: see how to avoid dup()
|
|
|
|
INDArray arr = Nd4j.createArrayFromShapeBuffer(compressedDataBuffer.dup(), shapeBuff.dup());
|
|
|
|
//advance past the data
|
|
|
|
int compressLength = (int) compressionDescriptor.getCompressedLength();
|
2022-09-20 15:40:53 +02:00
|
|
|
byteBuffer.position(byteBuffer.position() + compressLength);
|
2019-06-06 15:21:15 +03:00
|
|
|
return Pair.of(arr, byteBuffer);
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Convert an ndarray to an unsafe buffer
|
|
|
|
* for use by aeron
|
|
|
|
* @param arr the array to convert
|
|
|
|
* @return the unsafebuffer representation of this array
|
|
|
|
*/
|
|
|
|
public static ByteBuffer toByteBuffer(INDArray arr) {
|
|
|
|
//subset and get rid of 1 off non 1 element wise stride cases
|
|
|
|
if (arr.isView())
|
|
|
|
arr = arr.dup();
|
|
|
|
if (!arr.isCompressed()) {
|
|
|
|
ByteBuffer b3 = ByteBuffer.allocateDirect(byteBufferSizeFor(arr)).order(ByteOrder.nativeOrder());
|
|
|
|
doByteBufferPutUnCompressed(arr, b3, true);
|
|
|
|
return b3;
|
|
|
|
}
|
|
|
|
//compressed array
|
|
|
|
else {
|
|
|
|
ByteBuffer b3 = ByteBuffer.allocateDirect(byteBufferSizeFor(arr)).order(ByteOrder.nativeOrder());
|
|
|
|
doByteBufferPutCompressed(arr, b3, true);
|
|
|
|
return b3;
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Returns the byte buffer size for the given
|
|
|
|
* ndarray. This is an auxillary method
|
|
|
|
* for determining the size of the buffer
|
|
|
|
* size to allocate for sending an ndarray via
|
|
|
|
* the aeron media driver.
|
|
|
|
*
|
|
|
|
* The math break down for uncompressed is:
|
|
|
|
* 2 ints for rank of the array and an ordinal representing the data opType of the data buffer
|
|
|
|
* The rest is in order:
|
|
|
|
* shape information
|
|
|
|
* data buffer
|
|
|
|
*
|
|
|
|
* The math break down for compressed is:
|
|
|
|
* 2 ints for rank and an ordinal representing the data opType for the data buffer
|
|
|
|
*
|
|
|
|
* The rest is in order:
|
|
|
|
* shape information
|
|
|
|
* codec information
|
|
|
|
* data buffer
|
|
|
|
*
|
|
|
|
* @param arr the array to compute the size for
|
|
|
|
* @return the size of the byte buffer that was allocated
|
|
|
|
*/
|
|
|
|
public static int byteBufferSizeFor(INDArray arr) {
|
|
|
|
if (!arr.isCompressed()) {
|
|
|
|
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
//2 four byte ints at the beginning
|
|
|
|
int twoInts = 8;
|
|
|
|
return twoInts + buffer.limit() + shapeBuffer.limit();
|
|
|
|
} else {
|
|
|
|
CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
|
|
|
|
CompressionDescriptor descriptor = compressedDataBuffer.getCompressionDescriptor();
|
|
|
|
ByteBuffer codecByteBuffer = descriptor.toByteBuffer();
|
|
|
|
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
int twoInts = 2 * 4;
|
|
|
|
return twoInts + buffer.limit() + shapeBuffer.limit() + codecByteBuffer.limit();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Setup the given byte buffer
|
|
|
|
* for serialization (note that this is for uncompressed INDArrays)
|
|
|
|
* 4 bytes int for rank
|
|
|
|
* 4 bytes for data opType
|
|
|
|
* shape buffer
|
|
|
|
* data buffer
|
|
|
|
*
|
|
|
|
* @param arr the array to setup
|
|
|
|
* @param allocated the byte buffer to setup
|
|
|
|
* @param rewind whether to rewind the byte buffer or nt
|
|
|
|
*/
|
|
|
|
public static void doByteBufferPutUnCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
|
|
|
|
// ensure we send data to host memory
|
|
|
|
Nd4j.getExecutioner().commit();
|
|
|
|
Nd4j.getAffinityManager().ensureLocation(arr, AffinityManager.Location.HOST);
|
|
|
|
|
|
|
|
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
//2 four byte ints at the beginning
|
|
|
|
allocated.putInt(arr.rank());
|
|
|
|
//put data opType next so its self describing
|
|
|
|
allocated.putInt(arr.data().dataType().ordinal());
|
|
|
|
allocated.put(shapeBuffer);
|
|
|
|
allocated.put(buffer);
|
|
|
|
if (rewind)
|
2019-11-20 07:43:17 +03:00
|
|
|
((Buffer) allocated).rewind();
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Setup the given byte buffer
|
|
|
|
* for serialization (note that this is for compressed INDArrays)
|
|
|
|
* 4 bytes for rank
|
|
|
|
* 4 bytes for data opType
|
|
|
|
* shape information
|
|
|
|
* codec information
|
|
|
|
* data opType
|
|
|
|
*
|
|
|
|
* @param arr the array to setup
|
|
|
|
* @param allocated the byte buffer to setup
|
|
|
|
* @param rewind whether to rewind the byte buffer or not
|
|
|
|
*/
|
|
|
|
public static void doByteBufferPutCompressed(INDArray arr, ByteBuffer allocated, boolean rewind) {
|
|
|
|
CompressedDataBuffer compressedDataBuffer = (CompressedDataBuffer) arr.data();
|
|
|
|
CompressionDescriptor descriptor = compressedDataBuffer.getCompressionDescriptor();
|
|
|
|
ByteBuffer codecByteBuffer = descriptor.toByteBuffer();
|
|
|
|
ByteBuffer buffer = arr.data().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
ByteBuffer shapeBuffer = arr.shapeInfoDataBuffer().pointer().asByteBuffer().order(ByteOrder.nativeOrder());
|
|
|
|
allocated.putInt(arr.rank());
|
|
|
|
//put data opType next so its self describing
|
|
|
|
allocated.putInt(arr.data().dataType().ordinal());
|
|
|
|
//put shape next
|
|
|
|
allocated.put(shapeBuffer);
|
|
|
|
//put codec information next
|
|
|
|
allocated.put(codecByteBuffer);
|
|
|
|
//finally put the data
|
|
|
|
allocated.put(buffer);
|
|
|
|
if (rewind)
|
2019-11-20 07:43:17 +03:00
|
|
|
((Buffer) allocated).rewind();
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Write an array to an output stream.
|
|
|
|
* @param arr the array to write
|
|
|
|
* @param outputStream the output stream to write to
|
|
|
|
*/
|
|
|
|
public static void writeArrayToOutputStream(INDArray arr, OutputStream outputStream) {
|
|
|
|
ByteBuffer buffer = BinarySerde.toByteBuffer(arr);
|
|
|
|
try (WritableByteChannel channel = Channels.newChannel(outputStream)) {
|
|
|
|
channel.write(buffer);
|
|
|
|
} catch (IOException e) {
|
2020-04-23 01:36:49 +03:00
|
|
|
log.error("",e);
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Write an ndarray to disk in
|
|
|
|
* binary format
|
|
|
|
* @param arr the array to write
|
|
|
|
* @param toWrite the file tow rite to
|
2019-10-08 19:48:22 +09:00
|
|
|
* @throws IOException on an I/O exception.
|
2019-06-06 15:21:15 +03:00
|
|
|
*/
|
|
|
|
public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
|
|
|
|
try (FileOutputStream os = new FileOutputStream(toWrite)) {
|
|
|
|
FileChannel channel = os.getChannel();
|
|
|
|
ByteBuffer buffer = BinarySerde.toByteBuffer(arr);
|
|
|
|
channel.write(buffer);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* Read an ndarray from disk
|
2019-10-08 19:48:22 +09:00
|
|
|
* @param readFrom file to read
|
|
|
|
* @return the created INDArray.
|
|
|
|
* @throws IOException on an I/O exception.
|
2019-06-06 15:21:15 +03:00
|
|
|
*/
|
|
|
|
public static INDArray readFromDisk(File readFrom) throws IOException {
|
|
|
|
try (FileInputStream os = new FileInputStream(readFrom)) {
|
|
|
|
FileChannel channel = os.getChannel();
|
|
|
|
ByteBuffer buffer = ByteBuffer.allocateDirect((int) readFrom.length());
|
|
|
|
channel.read(buffer);
|
2019-10-08 19:48:22 +09:00
|
|
|
return toArray(buffer);
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* This method returns shape databuffer from saved earlier file
|
|
|
|
*
|
2019-10-08 19:48:22 +09:00
|
|
|
* @param readFrom file to read
|
|
|
|
* @return the created databuffer,
|
|
|
|
* @throws IOException on an I/O exception.
|
2019-06-06 15:21:15 +03:00
|
|
|
*/
|
|
|
|
public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
|
|
|
|
try (FileInputStream os = new FileInputStream(readFrom)) {
|
|
|
|
FileChannel channel = os.getChannel();
|
|
|
|
// we read shapeinfo up to max_rank value, which is 32
|
|
|
|
int len = (int) Math.min((32 * 2 + 3) * 8, readFrom.length());
|
|
|
|
ByteBuffer buffer = ByteBuffer.allocateDirect(len);
|
|
|
|
channel.read(buffer);
|
|
|
|
|
2019-10-08 19:48:22 +09:00
|
|
|
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
|
2019-06-06 15:21:15 +03:00
|
|
|
|
2019-11-20 07:43:17 +03:00
|
|
|
((Buffer) buffer).position(0);
|
2019-06-06 15:21:15 +03:00
|
|
|
int rank = byteBuffer.getInt();
|
|
|
|
|
|
|
|
val result = new long[Shape.shapeInfoLength(rank)];
|
|
|
|
|
|
|
|
// filling DataBuffer with shape info
|
|
|
|
result[0] = rank;
|
|
|
|
|
|
|
|
// skipping two next values (dtype and rank again)
|
|
|
|
// please , that this time rank has dtype of LONG, so takes 8 bytes.
|
2019-11-20 07:43:17 +03:00
|
|
|
((Buffer) byteBuffer).position(16);
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
// filling shape information
|
|
|
|
for (int e = 1; e < Shape.shapeInfoLength(rank); e++) {
|
|
|
|
result[e] = byteBuffer.getLong();
|
|
|
|
}
|
|
|
|
|
|
|
|
// creating nd4j databuffer now
|
2019-10-08 19:48:22 +09:00
|
|
|
return Nd4j.getDataBufferFactory().createLong(result);
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|