nd4j-api cleanup. (#8273)
* nd4j-api cleanup. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * restore deleted schemes. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
1f4ad08305
commit
50b13fadc8
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -31,101 +31,18 @@ public class Nd4jBase64 {
|
||||||
|
|
||||||
private Nd4jBase64() {}
|
private Nd4jBase64() {}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns true if the base64
|
|
||||||
* contains multiple arrays
|
|
||||||
* This is delimited by tab
|
|
||||||
* @param base64 the base 64 to test
|
|
||||||
* @return true if the given base 64
|
|
||||||
* is tab delimited or not
|
|
||||||
*/
|
|
||||||
public static boolean isMultiple(String base64) {
|
|
||||||
return base64.contains("\t");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a set of arrays
|
|
||||||
* from base 64 that is tab delimited.
|
|
||||||
* @param base64 the base 64 that's tab delimited
|
|
||||||
* @return the set of arrays
|
|
||||||
*/
|
|
||||||
public static INDArray[] arraysFromBase64(String base64) throws IOException {
|
|
||||||
String[] base64Arr = base64.split("\t");
|
|
||||||
INDArray[] ret = new INDArray[base64Arr.length];
|
|
||||||
for (int i = 0; i < base64Arr.length; i++) {
|
|
||||||
byte[] decode = Base64.decodeBase64(base64Arr[i]);
|
|
||||||
ByteArrayInputStream bis = new ByteArrayInputStream(decode);
|
|
||||||
DataInputStream dis = new DataInputStream(bis);
|
|
||||||
INDArray predict = Nd4j.read(dis);
|
|
||||||
ret[i] = predict;
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a tab delimited base 64
|
|
||||||
* representation of the given arrays
|
|
||||||
* @param arrays the arrays
|
|
||||||
* @return
|
|
||||||
* @throws IOException
|
|
||||||
*/
|
|
||||||
public static String arraysToBase64(INDArray[] arrays) throws IOException {
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
//tab separate the outputs for de serialization
|
|
||||||
for (INDArray outputArr : arrays) {
|
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
|
||||||
DataOutputStream dos = new DataOutputStream(bos);
|
|
||||||
Nd4j.write(outputArr, dos);
|
|
||||||
String base64 = Base64.encodeBase64String(bos.toByteArray());
|
|
||||||
sb.append(base64);
|
|
||||||
sb.append("\t");
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert an {@link INDArray}
|
|
||||||
* to numpy byte array using
|
|
||||||
* {@link Nd4j#toNpyByteArray(INDArray)}
|
|
||||||
* @param arr the input array
|
|
||||||
* @return the base 64ed binary
|
|
||||||
* @throws IOException
|
|
||||||
*/
|
|
||||||
public static String base64StringNumpy(INDArray arr) throws IOException {
|
|
||||||
byte[] bytes = Nd4j.toNpyByteArray(arr);
|
|
||||||
return Base64.encodeBase64String(bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert a numpy array from base64
|
|
||||||
* to a byte array and then
|
|
||||||
* create an {@link INDArray}
|
|
||||||
* from {@link Nd4j#createNpyFromByteArray(byte[])}
|
|
||||||
* @param base64 the base 64 byte array
|
|
||||||
* @return the created {@link INDArray}
|
|
||||||
*/
|
|
||||||
public static INDArray fromNpyBase64(String base64) {
|
|
||||||
byte[] bytes = Base64.decodeBase64(base64);
|
|
||||||
return Nd4j.createNpyFromByteArray(bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an ndarray
|
* Returns an ndarray
|
||||||
* as base 64
|
* as base 64
|
||||||
* @param arr the array to write
|
* @param arr the array to write
|
||||||
* @return the base 64 representation of the binary
|
* @return the base 64 representation of the binary
|
||||||
* ndarray
|
* ndarray
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static String base64String(INDArray arr) throws IOException {
|
public static String base64String(INDArray arr) throws IOException {
|
||||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||||
DataOutputStream dos = new DataOutputStream(bos);
|
DataOutputStream dos = new DataOutputStream(bos);
|
||||||
Nd4j.write(arr, dos);
|
Nd4j.write(arr, dos);
|
||||||
String base64 = Base64.encodeBase64String(bos.toByteArray());
|
return Base64.encodeBase64String(bos.toByteArray());
|
||||||
return base64;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -133,14 +50,11 @@ public class Nd4jBase64 {
|
||||||
* representation
|
* representation
|
||||||
* @param base64 the base 64 to convert
|
* @param base64 the base 64 to convert
|
||||||
* @return the ndarray from base 64
|
* @return the ndarray from base 64
|
||||||
* @throws IOException
|
|
||||||
*/
|
*/
|
||||||
public static INDArray fromBase64(String base64) throws IOException {
|
public static INDArray fromBase64(String base64) {
|
||||||
byte[] arr = Base64.decodeBase64(base64);
|
byte[] arr = Base64.decodeBase64(base64);
|
||||||
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
||||||
DataInputStream dis = new DataInputStream(bis);
|
DataInputStream dis = new DataInputStream(bis);
|
||||||
INDArray predict = Nd4j.read(dis);
|
return Nd4j.read(dis);
|
||||||
return predict;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -42,7 +42,6 @@ import java.nio.channels.WritableByteChannel;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class BinarySerde {
|
public class BinarySerde {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create an ndarray
|
* Create an ndarray
|
||||||
* from the unsafe buffer
|
* from the unsafe buffer
|
||||||
|
@ -63,15 +62,13 @@ public class BinarySerde {
|
||||||
return toArray(buffer, 0);
|
return toArray(buffer, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create an ndarray and existing bytebuffer
|
* Create an ndarray and existing bytebuffer
|
||||||
* @param buffer
|
* @param buffer the buffer to create the arrays from
|
||||||
* @param offset
|
* @param offset position in buffer to create the arrays from.
|
||||||
* @return
|
* @return the created INDArray and Bytebuffer pair.
|
||||||
*/
|
*/
|
||||||
public static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
|
protected static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
|
||||||
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
||||||
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
||||||
//bump the byte buffer to the proper position
|
//bump the byte buffer to the proper position
|
||||||
|
@ -272,7 +269,7 @@ public class BinarySerde {
|
||||||
* binary format
|
* binary format
|
||||||
* @param arr the array to write
|
* @param arr the array to write
|
||||||
* @param toWrite the file tow rite to
|
* @param toWrite the file tow rite to
|
||||||
* @throws IOException
|
* @throws IOException on an I/O exception.
|
||||||
*/
|
*/
|
||||||
public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
|
public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
|
||||||
try (FileOutputStream os = new FileOutputStream(toWrite)) {
|
try (FileOutputStream os = new FileOutputStream(toWrite)) {
|
||||||
|
@ -285,27 +282,25 @@ public class BinarySerde {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read an ndarray from disk
|
* Read an ndarray from disk
|
||||||
* @param readFrom
|
* @param readFrom file to read
|
||||||
* @return
|
* @return the created INDArray.
|
||||||
* @throws IOException
|
* @throws IOException on an I/O exception.
|
||||||
*/
|
*/
|
||||||
public static INDArray readFromDisk(File readFrom) throws IOException {
|
public static INDArray readFromDisk(File readFrom) throws IOException {
|
||||||
try (FileInputStream os = new FileInputStream(readFrom)) {
|
try (FileInputStream os = new FileInputStream(readFrom)) {
|
||||||
FileChannel channel = os.getChannel();
|
FileChannel channel = os.getChannel();
|
||||||
ByteBuffer buffer = ByteBuffer.allocateDirect((int) readFrom.length());
|
ByteBuffer buffer = ByteBuffer.allocateDirect((int) readFrom.length());
|
||||||
channel.read(buffer);
|
channel.read(buffer);
|
||||||
INDArray ret = toArray(buffer);
|
return toArray(buffer);
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns shape databuffer from saved earlier file
|
* This method returns shape databuffer from saved earlier file
|
||||||
*
|
*
|
||||||
* @param readFrom
|
* @param readFrom file to read
|
||||||
* @return
|
* @return the created databuffer,
|
||||||
* @throws IOException
|
* @throws IOException on an I/O exception.
|
||||||
*/
|
*/
|
||||||
public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
|
public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
|
||||||
try (FileInputStream os = new FileInputStream(readFrom)) {
|
try (FileInputStream os = new FileInputStream(readFrom)) {
|
||||||
|
@ -315,8 +310,7 @@ public class BinarySerde {
|
||||||
ByteBuffer buffer = ByteBuffer.allocateDirect(len);
|
ByteBuffer buffer = ByteBuffer.allocateDirect(len);
|
||||||
channel.read(buffer);
|
channel.read(buffer);
|
||||||
|
|
||||||
ByteBuffer byteBuffer = buffer == null ? ByteBuffer.allocateDirect(buffer.array().length)
|
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
|
||||||
.put(buffer.array()).order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
|
||||||
|
|
||||||
buffer.position(0);
|
buffer.position(0);
|
||||||
int rank = byteBuffer.getInt();
|
int rank = byteBuffer.getInt();
|
||||||
|
@ -336,10 +330,7 @@ public class BinarySerde {
|
||||||
}
|
}
|
||||||
|
|
||||||
// creating nd4j databuffer now
|
// creating nd4j databuffer now
|
||||||
DataBuffer dataBuffer = Nd4j.getDataBufferFactory().createLong(result);
|
return Nd4j.getDataBufferFactory().createLong(result);
|
||||||
return dataBuffer;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,8 +34,6 @@ public class NDArrayDeSerializer extends JsonDeserializer<INDArray> {
|
||||||
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
|
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
|
||||||
JsonNode node = jp.getCodec().readTree(jp);
|
JsonNode node = jp.getCodec().readTree(jp);
|
||||||
String field = node.get("array").asText();
|
String field = node.get("array").asText();
|
||||||
INDArray ret = Nd4jBase64.fromBase64(field);
|
return Nd4jBase64.fromBase64(field);
|
||||||
return ret;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,10 +17,8 @@
|
||||||
package org.nd4j.versioncheck;
|
package org.nd4j.versioncheck;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.nd4j.config.ND4JSystemProperties;
|
import org.nd4j.config.ND4JSystemProperties;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
|
@ -45,7 +43,6 @@ public class VersionCheck {
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static final String VERSION_CHECK_PROPERTY = ND4JSystemProperties.VERSION_CHECK_PROPERTY;
|
public static final String VERSION_CHECK_PROPERTY = ND4JSystemProperties.VERSION_CHECK_PROPERTY;
|
||||||
public static final String GIT_PROPERTY_FILE_SUFFIX = "-git.properties";
|
public static final String GIT_PROPERTY_FILE_SUFFIX = "-git.properties";
|
||||||
public static final String PROPERTIES_FILE_SUFFIX = "properties";
|
|
||||||
|
|
||||||
private static final String SCALA_210_SUFFIX = "_2.10";
|
private static final String SCALA_210_SUFFIX = "_2.10";
|
||||||
private static final String SCALA_211_SUFFIX = "_2.11";
|
private static final String SCALA_211_SUFFIX = "_2.11";
|
||||||
|
|
|
@ -20,9 +20,11 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.BufferedInputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
/*******************************************************************************
|
/* *****************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
|
@ -16,16 +16,12 @@
|
||||||
|
|
||||||
package org.nd4j.serde.base64;
|
package org.nd4j.serde.base64;
|
||||||
|
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -42,30 +38,11 @@ public class Nd4jBase64Test extends BaseNd4jTest {
|
||||||
return 'c';
|
return 'c';
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testBase64Several() throws IOException {
|
|
||||||
INDArray[] arrs = new INDArray[2];
|
|
||||||
arrs[0] = Nd4j.linspace(1, 4, 4);
|
|
||||||
arrs[1] = arrs[0].dup();
|
|
||||||
assertArrayEquals(arrs, Nd4jBase64.arraysFromBase64(Nd4jBase64.arraysToBase64(arrs)));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBase64() throws Exception {
|
public void testBase64() throws Exception {
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||||
String base64 = Nd4jBase64.base64String(arr);
|
String base64 = Nd4jBase64.base64String(arr);
|
||||||
// assertTrue(Nd4jBase64.isMultiple(base64));
|
|
||||||
INDArray from = Nd4jBase64.fromBase64(base64);
|
INDArray from = Nd4jBase64.fromBase64(base64);
|
||||||
assertEquals(arr, from);
|
assertEquals(arr, from);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
|
||||||
public void testBase64Npy() throws Exception {
|
|
||||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
|
||||||
String base64Npy = Nd4jBase64.base64StringNumpy(arr);
|
|
||||||
INDArray fromBase64 = Nd4jBase64.fromNpyBase64(base64Npy);
|
|
||||||
assertEquals(arr,fromBase64);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue