diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
index 66c639f56..1fe0d20ff 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
@@ -1,5 +1,6 @@
- 2.11.12
- 2.11
-
@@ -73,29 +69,17 @@
- com.typesafe.play
- play-java_2.11
- ${playframework.version}
-
-
- com.google.code.findbugs
- jsr305
-
-
- org.apache.tomcat
- tomcat-servlet-api
-
-
- net.jodah
- typetools
-
-
+ io.vertx
+ vertx-core
+ ${vertx.version}
+
- net.jodah
- typetools
- ${jodah.typetools.version}
+ io.vertx
+ vertx-web
+ ${vertx.version}
+
com.mashape.unirest
unirest-java
@@ -108,25 +92,16 @@
${project.version}
test
-
- com.typesafe.play
- play-json_2.11
- ${playframework.version}
-
-
- com.typesafe.play
- play-server_2.11
- ${playframework.version}
-
com.beust
jcommander
${jcommander.version}
+
- com.typesafe.play
- play-netty-server_2.11
- ${playframework.version}
+ ch.qos.logback
+ logback-classic
+ test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
index a79b57b19..6610e75f9 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.vertx.core.AbstractVerticle;
+import io.vertx.core.Vertx;
+import io.vertx.ext.web.Router;
+import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
@@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
@@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
-import play.BuiltInComponents;
-import play.Mode;
-import play.libs.Json;
-import play.routing.Router;
-import play.routing.RoutingDsl;
-import play.server.Server;
import java.io.File;
import java.util.*;
-import static play.mvc.Controller.request;
-import static play.mvc.Results.*;
-
/**
* A rest server for using an
* {@link VPTree} based on loading an ndarray containing
@@ -57,22 +55,33 @@ import static play.mvc.Results.*;
* @author Adam Gibson
*/
@Slf4j
-public class NearestNeighborsServer {
- @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
- private String ndarrayPath = null;
- @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
- private String labelsPath = null;
- @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
- private int port = 9000;
- @Parameter(names = {"--similarityFunction"}, arity = 1)
- private String similarityFunction = "euclidean";
- @Parameter(names = {"--invert"}, arity = 1)
- private boolean invert = false;
+public class NearestNeighborsServer extends AbstractVerticle {
- private Server server;
+ private static class RunArgs {
+ @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
+ private String ndarrayPath = null;
+ @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
+ private String labelsPath = null;
+ @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
+ private int port = 9000;
+ @Parameter(names = {"--similarityFunction"}, arity = 1)
+ private String similarityFunction = "euclidean";
+ @Parameter(names = {"--invert"}, arity = 1)
+ private boolean invert = false;
+ }
- public void runMain(String... args) throws Exception {
- JCommander jcmdr = new JCommander(this);
+ private static RunArgs instanceArgs;
+ private static NearestNeighborsServer instance;
+
+ public NearestNeighborsServer(){ }
+
+ public static NearestNeighborsServer getInstance(){
+ return instance;
+ }
+
+ public static void runMain(String... args) {
+ RunArgs r = new RunArgs();
+ JCommander jcmdr = new JCommander(r);
try {
jcmdr.parse(args);
@@ -84,7 +93,7 @@ public class NearestNeighborsServer {
//User provides invalid input -> print the usage info
jcmdr.usage();
- if (ndarrayPath == null)
+ if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)");
try {
Thread.sleep(500);
@@ -93,16 +102,20 @@ public class NearestNeighborsServer {
System.exit(1);
}
+ instanceArgs = r;
try {
- runHelper();
+ Vertx vertx = Vertx.vertx();
+ vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t);
}
}
- protected void runHelper() throws Exception {
+ @Override
+ public void start() throws Exception {
+ instance = this;
- String[] pathArr = ndarrayPath.split(",");
+ String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length];
// first of all we reading shapes of saved eariler files
int rows = 0;
@@ -111,7 +124,7 @@ public class NearestNeighborsServer {
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
- Shape.size(shape, 1));
+ Shape.size(shape, 1));
if (Shape.rank(shape) != 2)
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
@@ -122,12 +135,12 @@ public class NearestNeighborsServer {
cols = Shape.size(shape, 1);
else if (cols != Shape.size(shape, 1))
throw new DL4JInvalidInputException(
- "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
+ "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
}
final List labels = new ArrayList<>();
- if (labelsPath != null) {
- String[] labelsPathArr = labelsPath.split(",");
+ if (instanceArgs.labelsPath != null) {
+ String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
}
@@ -149,7 +162,7 @@ public class NearestNeighborsServer {
System.gc();
}
- VPTree tree = new VPTree(points, similarityFunction, invert);
+ VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
@@ -163,40 +176,57 @@ public class NearestNeighborsServer {
System.setProperty("play.crypto.secret", base64);
}
+ Router r = Router.router(vertx);
+ r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
+ createRoutes(r, labels, tree, points);
- server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
+ vertx.createHttpServer()
+ .requestHandler(r)
+ .listen(instanceArgs.port);
}
- protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){
- RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
- //return the host information for a given id
- routingDsl.POST("/knn").routingTo(request -> {
+ private void createRoutes(Router r, List labels, VPTree tree, INDArray points){
+
+ r.post("/knn").handler(rc -> {
try {
- NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
+ String json = rc.getBodyAsJson().encode();
+ NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
+
NearestNeighbor nearestNeighbor =
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
- if (record == null)
- return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
+ if (record == null) {
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
+ return;
+ }
- NearestNeighborsResults results =
- NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
-
-
- return ok(Json.toJson(results));
+ NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(results));
+ return;
} catch (Throwable e) {
log.error("Error in POST /knn",e);
e.printStackTrace();
- return internalServerError(e.getMessage());
+ rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end("Error parsing request - " + e.getMessage());
+ return;
}
});
- routingDsl.POST("/knnnew").routingTo(request -> {
+ r.post("/knnnew").handler(rc -> {
try {
- Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
- if (record == null)
- return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
+ String json = rc.getBodyAsJson().encode();
+ Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
+ if (record == null) {
+ rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
+ .putHeader("content-type", "application/json")
+ .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
+ return;
+ }
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List results;
@@ -214,9 +244,10 @@ public class NearestNeighborsServer {
}
if (results.size() != distances.size()) {
- return internalServerError(
- String.format("results.size == %d != %d == distances.size",
- results.size(), distances.size()));
+ rc.response()
+ .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
+ return;
}
List nnResult = new ArrayList<>();
@@ -228,30 +259,29 @@ public class NearestNeighborsServer {
}
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
- return ok(Json.toJson(results2));
-
+ String j = JsonMappers.getMapper().writeValueAsString(results2);
+ rc.response()
+ .putHeader("content-type", "application/json")
+ .end(j);
} catch (Throwable e) {
log.error("Error in POST /knnnew",e);
e.printStackTrace();
- return internalServerError(e.getMessage());
+ rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
+ .end("Error parsing request - " + e.getMessage());
+ return;
}
});
-
- return routingDsl.build();
}
/**
* Stop the server
*/
- public void stop() {
- if (server != null) {
- log.info("Attempting to stop server");
- server.stop();
- }
+ public void stop() throws Exception {
+ super.stop();
}
public static void main(String[] args) throws Exception {
- new NearestNeighborsServer().runMain(args);
+ runMain(args);
}
}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
index 9f8fd7241..b42c407e5 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
public TemporaryFolder testDir = new TemporaryFolder();
@Test
- //@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data);
@@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp);
- NearestNeighborsServer server = new NearestNeighborsServer();
- server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
+ NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort));
+ Thread.sleep(3000);
+
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size());
- server.stop();
+ NearestNeighborsServer.getInstance().stop();
}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
new file mode 100644
index 000000000..7953c2712
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
@@ -0,0 +1,42 @@
+
+
+
+
+
+ logs/application.log
+
+ %date - [%level] - from %logger in %thread
+ %n%message%n%xException%n
+
+
+
+
+
+ %logger{15} - %message%n%xException{5}
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
index d820dd6b7..720e7a5f3 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
@@ -1,5 +1,6 @@