277 lines
11 KiB
Java
277 lines
11 KiB
Java
|
|
package org.deeplearning4j.remote;
|
||
|
|
|
||
|
|
import lombok.val;
|
||
|
|
import org.datavec.image.loader.Java2DNativeImageLoader;
|
||
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||
|
|
import org.deeplearning4j.remote.helpers.ImageConversionUtils;
|
||
|
|
import org.deeplearning4j.util.ModelSerializer;
|
||
|
|
import org.junit.After;
|
||
|
|
import org.junit.Test;
|
||
|
|
import org.nd4j.adapters.InferenceAdapter;
|
||
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||
|
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||
|
|
import org.nd4j.linalg.io.ClassPathResource;
|
||
|
|
import org.nd4j.remote.clients.JsonRemoteInference;
|
||
|
|
import org.nd4j.remote.clients.serde.BinaryDeserializer;
|
||
|
|
import org.nd4j.remote.clients.serde.BinarySerializer;
|
||
|
|
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||
|
|
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||
|
|
import org.nd4j.remote.clients.serde.impl.IntegerSerde;
|
||
|
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||
|
|
|
||
|
|
import javax.imageio.ImageIO;
|
||
|
|
import java.awt.image.BufferedImage;
|
||
|
|
import java.io.*;
|
||
|
|
import java.nio.Buffer;
|
||
|
|
import java.util.concurrent.Future;
|
||
|
|
import java.util.concurrent.TimeUnit;
|
||
|
|
|
||
|
|
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||
|
|
import static org.junit.Assert.*;
|
||
|
|
|
||
|
|
public class BinaryModelServerTest {
|
||
|
|
// Port passed to every server builder in this class; the client endpoint URLs are built from the same value.
private final int PORT = 18080;
|
||
|
|
|
||
|
|
@After
|
||
|
|
public void pause() throws Exception {
|
||
|
|
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
||
|
|
TimeUnit.SECONDS.sleep(2);
|
||
|
|
}
|
||
|
|
|
||
|
|
// Internal test for locally defined serializers
|
||
|
|
@Test
|
||
|
|
public void testBufferedImageSerde() {
|
||
|
|
BinarySerializer<BufferedImage> serde = new BinaryModelServerTest.BufferedImageSerde();
|
||
|
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||
|
|
byte[] serialized = serde.serialize(image);
|
||
|
|
|
||
|
|
BufferedImage deserialized = ((BufferedImageSerde) serde).deserialize(serialized);
|
||
|
|
int originalSize = image.getData().getDataBuffer().getSize();
|
||
|
|
assertEquals(originalSize, deserialized.getData().getDataBuffer().getSize());
|
||
|
|
for (int i = 0; i < originalSize; ++i) {
|
||
|
|
assertEquals(deserialized.getData().getDataBuffer().getElem(i),
|
||
|
|
image.getData().getDataBuffer().getElem(i));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testImageToINDArray() {
|
||
|
|
INDArray data = ImageConversionUtils.makeRandomImageAsINDArray(28,28,1);
|
||
|
|
assertNotNull(data);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testMlnMnist_ImageInput() throws Exception {
|
||
|
|
|
||
|
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||
|
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||
|
|
|
||
|
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||
|
|
.outputSerializer(new IntegerSerde())
|
||
|
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||
|
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||
|
|
@Override
|
||
|
|
public MultiDataSet apply(BufferedImage input) {
|
||
|
|
INDArray data = null;
|
||
|
|
try {
|
||
|
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||
|
|
data = data.reshape(1, 784);
|
||
|
|
}
|
||
|
|
catch (IOException e) {
|
||
|
|
throw new RuntimeException(e);
|
||
|
|
}
|
||
|
|
return new MultiDataSet(data, null);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public Integer apply(INDArray... nnOutput) {
|
||
|
|
return nnOutput[0].argMax().getInt(0);
|
||
|
|
}
|
||
|
|
})
|
||
|
|
.port(PORT)
|
||
|
|
.inferenceMode(SEQUENTIAL)
|
||
|
|
.numWorkers(1)
|
||
|
|
.parallelMode(false)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||
|
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||
|
|
.inputBinarySerializer(new BufferedImageSerde())
|
||
|
|
.outputDeserializer(new IntegerSerde())
|
||
|
|
.build();
|
||
|
|
|
||
|
|
try {
|
||
|
|
server.start();
|
||
|
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||
|
|
Integer result = client.predict(image);
|
||
|
|
assertNotNull(result);
|
||
|
|
|
||
|
|
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||
|
|
image = ImageIO.read(new FileInputStream(file));
|
||
|
|
result = client.predict(image);
|
||
|
|
assertEquals(new Integer(0), result);
|
||
|
|
|
||
|
|
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||
|
|
image = ImageIO.read(new FileInputStream(file));
|
||
|
|
result = client.predict(image);
|
||
|
|
assertEquals(new Integer(1), result);
|
||
|
|
|
||
|
|
} catch (Exception e){
|
||
|
|
e.printStackTrace();
|
||
|
|
throw e;
|
||
|
|
} finally {
|
||
|
|
server.stop();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testMlnMnist_ImageInput_Async() throws Exception {
|
||
|
|
|
||
|
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||
|
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||
|
|
|
||
|
|
val server = new JsonModelServer.Builder<BufferedImage, Integer>(net)
|
||
|
|
.outputSerializer(new IntegerSerde())
|
||
|
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||
|
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, Integer>() {
|
||
|
|
@Override
|
||
|
|
public MultiDataSet apply(BufferedImage input) {
|
||
|
|
INDArray data = null;
|
||
|
|
try {
|
||
|
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||
|
|
data = data.reshape(1, 784);
|
||
|
|
}
|
||
|
|
catch (IOException e) {
|
||
|
|
throw new RuntimeException(e);
|
||
|
|
}
|
||
|
|
return new MultiDataSet(data, null);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public Integer apply(INDArray... nnOutput) {
|
||
|
|
return nnOutput[0].argMax().getInt(0);
|
||
|
|
}
|
||
|
|
})
|
||
|
|
.port(PORT)
|
||
|
|
.inferenceMode(SEQUENTIAL)
|
||
|
|
.numWorkers(1)
|
||
|
|
.parallelMode(false)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
val client = JsonRemoteInference.<BufferedImage, Integer>builder()
|
||
|
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||
|
|
.inputBinarySerializer(new BufferedImageSerde())
|
||
|
|
.outputDeserializer(new IntegerSerde())
|
||
|
|
.build();
|
||
|
|
|
||
|
|
try {
|
||
|
|
server.start();
|
||
|
|
BufferedImage[] images = new BufferedImage[3];
|
||
|
|
images[0] = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||
|
|
|
||
|
|
File file = new ClassPathResource("datavec-local/imagetest/0/b.bmp").getFile();
|
||
|
|
images[1] = ImageIO.read(new FileInputStream(file));
|
||
|
|
|
||
|
|
file = new ClassPathResource("datavec-local/imagetest/1/a.bmp").getFile();
|
||
|
|
images[2] = ImageIO.read(new FileInputStream(file));
|
||
|
|
|
||
|
|
Future<Integer>[] results = new Future[3];
|
||
|
|
for (int i = 0; i < images.length; ++i) {
|
||
|
|
results[i] = client.predictAsync(images[i]);
|
||
|
|
assertNotNull(results[i]);
|
||
|
|
}
|
||
|
|
|
||
|
|
assertNotNull(results[0].get());
|
||
|
|
assertEquals(new Integer(0), results[1].get());
|
||
|
|
assertEquals(new Integer(1), results[2].get());
|
||
|
|
|
||
|
|
} catch (Exception e){
|
||
|
|
e.printStackTrace();
|
||
|
|
throw e;
|
||
|
|
} finally {
|
||
|
|
server.stop();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
@Test
|
||
|
|
public void testBinaryIn_BinaryOut() throws Exception {
|
||
|
|
|
||
|
|
val modelFile = new ClassPathResource("models/mnist/mnist-model.zip").getFile();
|
||
|
|
MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
||
|
|
|
||
|
|
val server = new JsonModelServer.Builder<BufferedImage, BufferedImage>(net)
|
||
|
|
.outputBinarySerializer(new BufferedImageSerde())
|
||
|
|
.inputBinaryDeserializer(new BufferedImageSerde())
|
||
|
|
.inferenceAdapter(new InferenceAdapter<BufferedImage, BufferedImage>() {
|
||
|
|
@Override
|
||
|
|
public MultiDataSet apply(BufferedImage input) {
|
||
|
|
INDArray data = null;
|
||
|
|
try {
|
||
|
|
data = new Java2DNativeImageLoader().asMatrix(input);
|
||
|
|
}
|
||
|
|
catch (IOException e) {
|
||
|
|
throw new RuntimeException(e);
|
||
|
|
}
|
||
|
|
return new MultiDataSet(data, null);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public BufferedImage apply(INDArray... nnOutput) {
|
||
|
|
return ImageConversionUtils.makeRandomBufferedImage(28,28,3);
|
||
|
|
}
|
||
|
|
})
|
||
|
|
.port(PORT)
|
||
|
|
.inferenceMode(SEQUENTIAL)
|
||
|
|
.numWorkers(1)
|
||
|
|
.parallelMode(false)
|
||
|
|
.build();
|
||
|
|
|
||
|
|
val client = JsonRemoteInference.<BufferedImage, BufferedImage>builder()
|
||
|
|
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||
|
|
.inputBinarySerializer(new BufferedImageSerde())
|
||
|
|
.outputBinaryDeserializer(new BufferedImageSerde())
|
||
|
|
.build();
|
||
|
|
|
||
|
|
try {
|
||
|
|
server.start();
|
||
|
|
BufferedImage image = ImageConversionUtils.makeRandomBufferedImage(28,28,1);
|
||
|
|
BufferedImage result = client.predict(image);
|
||
|
|
assertNotNull(result);
|
||
|
|
assertEquals(28, result.getHeight());
|
||
|
|
assertEquals(28, result.getWidth());
|
||
|
|
|
||
|
|
} catch (Exception e){
|
||
|
|
e.printStackTrace();
|
||
|
|
throw e;
|
||
|
|
} finally {
|
||
|
|
server.stop();
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
private static class BufferedImageSerde implements BinarySerializer<BufferedImage>, BinaryDeserializer<BufferedImage> {
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public BufferedImage deserialize(byte[] buffer) {
|
||
|
|
try {
|
||
|
|
BufferedImage img = ImageIO.read(new ByteArrayInputStream(buffer));
|
||
|
|
return img;
|
||
|
|
} catch (IOException e){
|
||
|
|
throw new RuntimeException(e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
@Override
|
||
|
|
public byte[] serialize(BufferedImage image) {
|
||
|
|
try{
|
||
|
|
val baos = new ByteArrayOutputStream();
|
||
|
|
ImageIO.write(image, "bmp", baos);
|
||
|
|
byte[] bytes = baos.toByteArray();
|
||
|
|
return bytes;
|
||
|
|
} catch (IOException e){
|
||
|
|
throw new RuntimeException(e);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|