nd4j-api cleanup. (#8273)

* nd4j-api cleanup.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* restore deleted schemes.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-10-08 19:48:22 +09:00 committed by Alex Black
parent 1f4ad08305
commit 50b13fadc8
6 changed files with 26 additions and 147 deletions

View File

@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@ -31,101 +31,18 @@ public class Nd4jBase64 {
private Nd4jBase64() {}
/**
* Returns true if the base64
* contains multiple arrays
* This is delimited by tab
* @param base64 the base 64 to test
* @return true if the given base 64
* is tab delimited or not
*/
public static boolean isMultiple(String base64) {
return base64.contains("\t");
}
/**
* Returns a set of arrays
* from base 64 that is tab delimited.
* @param base64 the base 64 that's tab delimited
* @return the set of arrays
*/
public static INDArray[] arraysFromBase64(String base64) throws IOException {
String[] base64Arr = base64.split("\t");
INDArray[] ret = new INDArray[base64Arr.length];
for (int i = 0; i < base64Arr.length; i++) {
byte[] decode = Base64.decodeBase64(base64Arr[i]);
ByteArrayInputStream bis = new ByteArrayInputStream(decode);
DataInputStream dis = new DataInputStream(bis);
INDArray predict = Nd4j.read(dis);
ret[i] = predict;
}
return ret;
}
/**
* Returns a tab delimited base 64
* representation of the given arrays
* @param arrays the arrays
* @return
* @throws IOException
*/
public static String arraysToBase64(INDArray[] arrays) throws IOException {
StringBuilder sb = new StringBuilder();
//tab separate the outputs for de serialization
for (INDArray outputArr : arrays) {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
Nd4j.write(outputArr, dos);
String base64 = Base64.encodeBase64String(bos.toByteArray());
sb.append(base64);
sb.append("\t");
}
return sb.toString();
}
/**
* Convert an {@link INDArray}
* to numpy byte array using
* {@link Nd4j#toNpyByteArray(INDArray)}
* @param arr the input array
* @return the base 64ed binary
* @throws IOException
*/
public static String base64StringNumpy(INDArray arr) throws IOException {
byte[] bytes = Nd4j.toNpyByteArray(arr);
return Base64.encodeBase64String(bytes);
}
/**
* Convert a numpy array from base64
* to a byte array and then
* create an {@link INDArray}
* from {@link Nd4j#createNpyFromByteArray(byte[])}
* @param base64 the base 64 byte array
* @return the created {@link INDArray}
*/
public static INDArray fromNpyBase64(String base64) {
byte[] bytes = Base64.decodeBase64(base64);
return Nd4j.createNpyFromByteArray(bytes);
}
/**
* Returns an ndarray
* as base 64
* @param arr the array to write
* @return the base 64 representation of the binary
* ndarray
* @throws IOException
*/
public static String base64String(INDArray arr) throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(bos);
Nd4j.write(arr, dos);
String base64 = Base64.encodeBase64String(bos.toByteArray());
return base64;
return Base64.encodeBase64String(bos.toByteArray());
}
/**
@ -133,14 +50,11 @@ public class Nd4jBase64 {
* representation
* @param base64 the base 64 to convert
* @return the ndarray from base 64
* @throws IOException
*/
public static INDArray fromBase64(String base64) throws IOException {
public static INDArray fromBase64(String base64) {
byte[] arr = Base64.decodeBase64(base64);
ByteArrayInputStream bis = new ByteArrayInputStream(arr);
DataInputStream dis = new DataInputStream(bis);
INDArray predict = Nd4j.read(dis);
return predict;
return Nd4j.read(dis);
}
}

View File

@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@ -42,7 +42,6 @@ import java.nio.channels.WritableByteChannel;
@Slf4j
public class BinarySerde {
/**
* Create an ndarray
* from the unsafe buffer
@ -63,15 +62,13 @@ public class BinarySerde {
return toArray(buffer, 0);
}
/**
* Create an ndarray and existing bytebuffer
* @param buffer
* @param offset
* @return
* @param buffer the buffer to create the arrays from
* @param offset position in buffer to create the arrays from.
* @return the created INDArray and Bytebuffer pair.
*/
public static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
protected static Pair<INDArray, ByteBuffer> toArrayAndByteBuffer(ByteBuffer buffer, int offset) {
ByteBuffer byteBuffer = buffer.hasArray() ? ByteBuffer.allocateDirect(buffer.array().length).put(buffer.array())
.order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
//bump the byte buffer to the proper position
@ -272,7 +269,7 @@ public class BinarySerde {
* binary format
* @param arr the array to write
* @param toWrite the file tow rite to
* @throws IOException
* @throws IOException on an I/O exception.
*/
public static void writeArrayToDisk(INDArray arr, File toWrite) throws IOException {
try (FileOutputStream os = new FileOutputStream(toWrite)) {
@ -285,27 +282,25 @@ public class BinarySerde {
/**
* Read an ndarray from disk
* @param readFrom
* @return
* @throws IOException
* @param readFrom file to read
* @return the created INDArray.
* @throws IOException on an I/O exception.
*/
public static INDArray readFromDisk(File readFrom) throws IOException {
try (FileInputStream os = new FileInputStream(readFrom)) {
FileChannel channel = os.getChannel();
ByteBuffer buffer = ByteBuffer.allocateDirect((int) readFrom.length());
channel.read(buffer);
INDArray ret = toArray(buffer);
return ret;
return toArray(buffer);
}
}
/**
* This method returns shape databuffer from saved earlier file
*
* @param readFrom
* @return
* @throws IOException
* @param readFrom file to read
* @return the created databuffer,
* @throws IOException on an I/O exception.
*/
public static DataBuffer readShapeFromDisk(File readFrom) throws IOException {
try (FileInputStream os = new FileInputStream(readFrom)) {
@ -315,8 +310,7 @@ public class BinarySerde {
ByteBuffer buffer = ByteBuffer.allocateDirect(len);
channel.read(buffer);
ByteBuffer byteBuffer = buffer == null ? ByteBuffer.allocateDirect(buffer.array().length)
.put(buffer.array()).order(ByteOrder.nativeOrder()) : buffer.order(ByteOrder.nativeOrder());
ByteBuffer byteBuffer = buffer.order(ByteOrder.nativeOrder());
buffer.position(0);
int rank = byteBuffer.getInt();
@ -336,10 +330,7 @@ public class BinarySerde {
}
// creating nd4j databuffer now
DataBuffer dataBuffer = Nd4j.getDataBufferFactory().createLong(result);
return dataBuffer;
return Nd4j.getDataBufferFactory().createLong(result);
}
}
}

View File

@ -34,8 +34,6 @@ public class NDArrayDeSerializer extends JsonDeserializer<INDArray> {
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
JsonNode node = jp.getCodec().readTree(jp);
String field = node.get("array").asText();
INDArray ret = Nd4jBase64.fromBase64(field);
return ret;
return Nd4jBase64.fromBase64(field);
}
}

View File

@ -17,10 +17,8 @@
package org.nd4j.versioncheck;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.nd4j.config.ND4JSystemProperties;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
@ -45,7 +43,6 @@ public class VersionCheck {
@Deprecated
public static final String VERSION_CHECK_PROPERTY = ND4JSystemProperties.VERSION_CHECK_PROPERTY;
public static final String GIT_PROPERTY_FILE_SUFFIX = "-git.properties";
public static final String PROPERTIES_FILE_SUFFIX = "properties";
private static final String SCALA_210_SUFFIX = "_2.10";
private static final String SCALA_211_SUFFIX = "_2.11";

View File

@ -20,9 +20,11 @@ import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.commons.io.FilenameUtils;
import java.io.*;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.util.Properties;

View File

@ -1,4 +1,4 @@
/*******************************************************************************
/* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
@ -16,16 +16,12 @@
package org.nd4j.serde.base64;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.IOException;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
@ -42,30 +38,11 @@ public class Nd4jBase64Test extends BaseNd4jTest {
return 'c';
}
@Test
public void testBase64Several() throws IOException {
INDArray[] arrs = new INDArray[2];
arrs[0] = Nd4j.linspace(1, 4, 4);
arrs[1] = arrs[0].dup();
assertArrayEquals(arrs, Nd4jBase64.arraysFromBase64(Nd4jBase64.arraysToBase64(arrs)));
}
@Test
public void testBase64() throws Exception {
INDArray arr = Nd4j.linspace(1, 4, 4);
String base64 = Nd4jBase64.base64String(arr);
// assertTrue(Nd4jBase64.isMultiple(base64));
INDArray from = Nd4jBase64.fromBase64(base64);
assertEquals(arr, from);
}
@Test
@Ignore("AB 2019/05/23 - Failing on linux-x86_64-cuda-9.2 - see issue #7657")
public void testBase64Npy() throws Exception {
INDArray arr = Nd4j.linspace(1, 4, 4);
String base64Npy = Nd4jBase64.base64StringNumpy(arr);
INDArray fromBase64 = Nd4jBase64.fromNpyBase64(base64Npy);
assertEquals(arr,fromBase64);
}
}