nd4j-api cleanup. (#8273)
* nd4j-api cleanup. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * restore deleted schemes. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
1f4ad08305
commit
50b13fadc8
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -31,101 +31,18 @@ public class Nd4jBase64 {
|
|||
|
||||
private Nd4jBase64() {}
|
||||
|
||||
/**
|
||||
* Returns true if the base64
|
||||
* contains multiple arrays
|
||||
* This is delimited by tab
|
||||
* @param base64 the base 64 to test
|
||||
* @return true if the given base 64
|
||||
* is tab delimited or not
|
||||
*/
|
||||
public static boolean isMultiple(String base64) {
|
||||
return base64.contains("\t");
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a set of arrays
|
||||
* from base 64 that is tab delimited.
|
||||
* @param base64 the base 64 that's tab delimited
|
||||
* @return the set of arrays
|
||||
*/
|
||||
public static INDArray[] arraysFromBase64(String base64) throws IOException {
|
||||
String[] base64Arr = base64.split("\t");
|
||||
INDArray[] ret = new INDArray[base64Arr.length];
|
||||
for (int i = 0; i < base64Arr.length; i++) {
|
||||
byte[] decode = Base64.decodeBase64(base64Arr[i]);
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(decode);
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
INDArray predict = Nd4j.read(dis);
|
||||
ret[i] = predict;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a tab delimited base 64
|
||||
* representation of the given arrays
|
||||
* @param arrays the arrays
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public static String arraysToBase64(INDArray[] arrays) throws IOException {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
//tab separate the outputs for de serialization
|
||||
for (INDArray outputArr : arrays) {
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(bos);
|
||||
Nd4j.write(outputArr, dos);
|
||||
String base64 = Base64.encodeBase64String(bos.toByteArray());
|
||||
sb.append(base64);
|
||||
sb.append("\t");
|
||||
}
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert an {@link INDArray}
|
||||
* to numpy byte array using
|
||||
* {@link Nd4j#toNpyByteArray(INDArray)}
|
||||
* @param arr the input array
|
||||
* @return the base 64ed binary
|
||||
* @throws IOException
|
||||
*/
|
||||
public static String base64StringNumpy(INDArray arr) throws IOException {
|
||||
byte[] bytes = Nd4j.toNpyByteArray(arr);
|
||||
return Base64.encodeBase64String(bytes);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert a numpy array from base64
|
||||
* to a byte array and then
|
||||
* create an {@link INDArray}
|
||||
* from {@link Nd4j#createNpyFromByteArray(byte[])}
|
||||
* @param base64 the base 64 byte array
|
||||
* @return the created {@link INDArray}
|
||||
*/
|
||||
public static INDArray fromNpyBase64(String base64) {
|
||||
byte[] bytes = Base64.decodeBase64(base64);
|
||||
return Nd4j.createNpyFromByteArray(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an ndarray
|
||||
* as base 64
|
||||
* @param arr the array to write
|
||||
* @return the base 64 representation of the binary
|
||||
* ndarray
|
||||
* @throws IOException
|
||||
*/
|
||||
public static String base64String(INDArray arr) throws IOException {
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(bos);
|
||||
Nd4j.write(arr, dos);
|
||||
String base64 = Base64.encodeBase64String(bos.toByteArray());
|
||||
return base64;
|
||||
return Base64.encodeBase64String(bos.toByteArray());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -133,14 +50,11 @@ public class Nd4jBase64 {
|
|||
* representation
|
||||
* @param base64 the base 64 to convert
|
||||
* @return the ndarray from base 64
|
||||
* @throws IOException
|
||||
*/
|
||||
public static INDArray fromBase64(String base64) throws IOException {
|
||||
public static INDArray fromBase64(String base64) {
|
||||
byte[] arr = Base64.decodeBase64(base64);
|
||||
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
|
||||
DataInputStream dis = new DataInputStream(bis);
|
||||
INDArray predict = Nd4j.read(dis);
|
||||
return predict;
|
||||
return Nd4j.read(dis);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -42,7 +42,6 @@ import java.nio.channels.WritableByteChannel;
|
|||
@Slf4j
|
||||
public class BinarySerde {
|
||||
|
||||
|
||||
/**
|
||||
* Create an ndarray
|
||||
* from the unsafe buffer
|
||||
|
@ -63,15 +62,13 @@ public class BinarySerde {
|
|||
return toArray(buffer, 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Create an ndarray and existing bytebuffer
|
||||
* @param buffer
|
||||
* @param offset
|
||||
* @return
|
||||
* @param buffer the buffer to create the arrays from
|
||||
* @param offset position in buffer to create the arrays from.
|
||||
* @return the created INDArray and Bytebuffer pair.
|
||||
*/
|
||||
public static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
|
||||
protected static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
|
||||
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
|
||||
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
||||
//bump the byte buffer to the proper position
|
||||
|
@ -272,7 +269,7 @@ public class BinarySerde {
|
|||
* binary format
|
||||
* @param arr the array to write
|
||||
* @param toWrite the file tow rite to
|
||||
* @throws IOException
|
||||
* @throws IOException on an I/O exception.
|
||||
*/
|
||||
public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
|
||||
try (FileOutputStream os = new FileOutputStream(toWrite)) {
|
||||
|
@ -285,27 +282,25 @@ public class BinarySerde {
|
|||
|
||||
/**
|
||||
* Read an ndarray from disk
|
||||
* @param readFrom
|
||||
* @return
|
||||
* @throws IOException
|
||||
* @param readFrom file to read
|
||||
* @return the created INDArray.
|
||||
* @throws IOException on an I/O exception.
|
||||
*/
|
||||
public static INDArray readFromDisk(File readFrom) throws IOException {
|
||||
try (FileInputStream os = new FileInputStream(readFrom)) {
|
||||
FileChannel channel = os.getChannel();
|
||||
ByteBuffer buffer = ByteBuffer.allocateDirect((int) readFrom.length());
|
||||
channel.read(buffer);
|
||||
INDArray ret = toArray(buffer);
|
||||
return ret;
|
||||
return toArray(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* This method returns shape databuffer from saved earlier file
|
||||
*
|
||||
* @param readFrom
|
||||
* @return
|
||||
* @throws IOException
|
||||
* @param readFrom file to read
|
||||
* @return the created databuffer,
|
||||
* @throws IOException on an I/O exception.
|
||||
*/
|
||||
public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
|
||||
try (FileInputStream os = new FileInputStream(readFrom)) {
|
||||
|
@ -315,8 +310,7 @@ public class BinarySerde {
|
|||
ByteBuffer buffer = ByteBuffer.allocateDirect(len);
|
||||
channel.read(buffer);
|
||||
|
||||
ByteBuffer byteBuffer = buffer == null ? ByteBuffer.allocateDirect(buffer.array().length)
|
||||
.put(buffer.array()).order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
|
||||
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
|
||||
|
||||
buffer.position(0);
|
||||
int rank = byteBuffer.getInt();
|
||||
|
@ -336,10 +330,7 @@ public class BinarySerde {
|
|||
}
|
||||
|
||||
// creating nd4j databuffer now
|
||||
DataBuffer dataBuffer = Nd4j.getDataBufferFactory().createLong(result);
|
||||
return dataBuffer;
|
||||
return Nd4j.getDataBufferFactory().createLong(result);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -34,8 +34,6 @@ public class NDArrayDeSerializer extends JsonDeserializer<INDArray> {
|
|||
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
|
||||
JsonNode node = jp.getCodec().readTree(jp);
|
||||
String field = node.get("array").asText();
|
||||
INDArray ret = Nd4jBase64.fromBase64(field);
|
||||
return ret;
|
||||
|
||||
return Nd4jBase64.fromBase64(field);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,10 +17,8 @@
|
|||
package org.nd4j.versioncheck;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
|
@ -45,7 +43,6 @@ public class VersionCheck {
|
|||
@Deprecated
|
||||
public static final String VERSION_CHECK_PROPERTY = ND4JSystemProperties.VERSION_CHECK_PROPERTY;
|
||||
public static final String GIT_PROPERTY_FILE_SUFFIX = "-git.properties";
|
||||
public static final String PROPERTIES_FILE_SUFFIX = "properties";
|
||||
|
||||
private static final String SCALA_210_SUFFIX = "_2.10";
|
||||
private static final String SCALA_211_SUFFIX = "_2.11";
|
||||
|
|
|
@ -20,9 +20,11 @@ import lombok.AllArgsConstructor;
|
|||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
|
||||
import java.io.*;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.Properties;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
/*******************************************************************************
|
||||
/* *****************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
|
@ -16,16 +16,12 @@
|
|||
|
||||
package org.nd4j.serde.base64;
|
||||
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
/**
|
||||
|
@ -42,30 +38,11 @@ public class Nd4jBase64Test extends BaseNd4jTest {
|
|||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBase64Several() throws IOException {
|
||||
INDArray[] arrs = new INDArray[2];
|
||||
arrs[0] = Nd4j.linspace(1, 4, 4);
|
||||
arrs[1] = arrs[0].dup();
|
||||
assertArrayEquals(arrs, Nd4jBase64.arraysFromBase64(Nd4jBase64.arraysToBase64(arrs)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBase64() throws Exception {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||
String base64 = Nd4jBase64.base64String(arr);
|
||||
// assertTrue(Nd4jBase64.isMultiple(base64));
|
||||
INDArray from = Nd4jBase64.fromBase64(base64);
|
||||
assertEquals(arr, from);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
|
||||
public void testBase64Npy() throws Exception {
|
||||
INDArray arr = Nd4j.linspace(1, 4, 4);
|
||||
String base64Npy = Nd4jBase64.base64StringNumpy(arr);
|
||||
INDArray fromBase64 = Nd4jBase64.fromNpyBase64(base64Npy);
|
||||
assertEquals(arr,fromBase64);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue