diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 66c639f56..1fe0d20ff 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -1,5 +1,6 @@ - 2.11.12 - 2.11 - @@ -73,29 +69,17 @@ - com.typesafe.play - play-java_2.11 - ${playframework.version} - - - com.google.code.findbugs - jsr305 - - - org.apache.tomcat - tomcat-servlet-api - - - net.jodah - typetools - - + io.vertx + vertx-core + ${vertx.version} + - net.jodah - typetools - ${jodah.typetools.version} + io.vertx + vertx-web + ${vertx.version} + com.mashape.unirest unirest-java @@ -108,25 +92,16 @@ ${project.version} test - - com.typesafe.play - play-json_2.11 - ${playframework.version} - - - com.typesafe.play - play-server_2.11 - ${playframework.version} - com.beust jcommander ${jcommander.version} + - com.typesafe.play - play-netty-server_2.11 - ${playframework.version} + ch.qos.logback + logback-classic + test diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index a79b57b19..6610e75f9 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server; import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.vertx.core.AbstractVerticle; +import io.vertx.core.Vertx; +import io.vertx.ext.web.Router; +import io.vertx.ext.web.handler.BodyHandler; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.clustering.sptree.DataPoint; @@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree; import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nearestneighbor.model.*; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.binary.BinarySerde; -import play.BuiltInComponents; -import play.Mode; -import play.libs.Json; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; import java.io.File; import java.util.*; -import static play.mvc.Controller.request; -import static play.mvc.Results.*; - /** * A rest server for using an * {@link VPTree} based on loading an ndarray containing @@ -57,22 +55,33 @@ import static play.mvc.Results.*; * @author Adam Gibson */ @Slf4j -public class NearestNeighborsServer { - @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) - private String ndarrayPath = null; - @Parameter(names = {"--labelsPath"}, arity = 1, required = false) - private String labelsPath = null; - @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) - private int port = 9000; - @Parameter(names = {"--similarityFunction"}, arity = 1) - private String similarityFunction = "euclidean"; - @Parameter(names = {"--invert"}, arity = 1) - private boolean invert = false; +public class NearestNeighborsServer extends AbstractVerticle { - private Server server; + private static class RunArgs { + @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) + private String ndarrayPath = null; + @Parameter(names = {"--labelsPath"}, arity = 1, required = false) + private String labelsPath = null; + @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) + private int port = 9000; + @Parameter(names = {"--similarityFunction"}, arity = 1) + private String similarityFunction = "euclidean"; + @Parameter(names = {"--invert"}, arity = 1) + private boolean invert = false; + } - public void runMain(String... args) throws Exception { - JCommander jcmdr = new JCommander(this); + private static RunArgs instanceArgs; + private static NearestNeighborsServer instance; + + public NearestNeighborsServer(){ } + + public static NearestNeighborsServer getInstance(){ + return instance; + } + + public static void runMain(String... args) { + RunArgs r = new RunArgs(); + JCommander jcmdr = new JCommander(r); try { jcmdr.parse(args); @@ -84,7 +93,7 @@ public class NearestNeighborsServer { //User provides invalid input -> print the usage info jcmdr.usage(); - if (ndarrayPath == null) + if (r.ndarrayPath == null) log.error("Json path parameter is missing (null)"); try { Thread.sleep(500); @@ -93,16 +102,20 @@ public class NearestNeighborsServer { System.exit(1); } + instanceArgs = r; try { - runHelper(); + Vertx vertx = Vertx.vertx(); + vertx.deployVerticle(NearestNeighborsServer.class.getName()); } catch (Throwable t){ log.error("Error in NearestNeighboursServer run method",t); } } - protected void runHelper() throws Exception { + @Override + public void start() throws Exception { + instance = this; - String[] pathArr = ndarrayPath.split(","); + String[] pathArr = instanceArgs.ndarrayPath.split(","); //INDArray[] pointsArr = new INDArray[pathArr.length]; // first of all we reading shapes of saved eariler files int rows = 0; @@ -111,7 +124,7 @@ public class NearestNeighborsServer { DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), - Shape.size(shape, 1)); + Shape.size(shape, 1)); if (Shape.rank(shape) != 2) throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); @@ -122,12 +135,12 @@ public class NearestNeighborsServer { cols = Shape.size(shape, 1); else if (cols != Shape.size(shape, 1)) throw new DL4JInvalidInputException( - "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); + "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); } final List labels = new ArrayList<>(); - if (labelsPath != null) { - String[] labelsPathArr = labelsPath.split(","); + if (instanceArgs.labelsPath != null) { + String[] labelsPathArr = instanceArgs.labelsPath.split(","); for (int i = 0; i < labelsPathArr.length; i++) { labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); } @@ -149,7 +162,7 @@ public class NearestNeighborsServer { System.gc(); } - VPTree tree = new VPTree(points, similarityFunction, invert); + VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); //Set play secret key, if required //http://www.playframework.com/documentation/latest/ApplicationSecret @@ -163,40 +176,57 @@ public class NearestNeighborsServer { System.setProperty("play.crypto.secret", base64); } + Router r = Router.router(vertx); + r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all + createRoutes(r, labels, tree, points); - server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); + vertx.createHttpServer() + .requestHandler(r) + .listen(instanceArgs.port); } - protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); - //return the host information for a given id - routingDsl.POST("/knn").routingTo(request -> { + private void createRoutes(Router r, List labels, VPTree tree, INDArray points){ + + r.post("/knn").handler(rc -> { try { - NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); + String json = rc.getBodyAsJson().encode(); + NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); + NearestNeighbor nearestNeighbor = NearestNeighbor.builder().points(points).record(record).tree(tree).build(); - if (record == null) - return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); + if (record == null) { + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); + return; + } - NearestNeighborsResults results = - NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); - - - return ok(Json.toJson(results)); + NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(results)); + return; } catch (Throwable e) { log.error("Error in POST /knn",e); e.printStackTrace(); - return internalServerError(e.getMessage()); + rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end("Error parsing request - " + e.getMessage()); + return; } }); - routingDsl.POST("/knnnew").routingTo(request -> { + r.post("/knnnew").handler(rc -> { try { - Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); - if (record == null) - return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); + String json = rc.getBodyAsJson().encode(); + Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); + if (record == null) { + rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) + .putHeader("content-type", "application/json") + .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); + return; + } INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); List results; @@ -214,9 +244,10 @@ public class NearestNeighborsServer { } if (results.size() != distances.size()) { - return internalServerError( - String.format("results.size == %d != %d == distances.size", - results.size(), distances.size())); + rc.response() + .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); + return; } List nnResult = new ArrayList<>(); @@ -228,30 +259,29 @@ public class NearestNeighborsServer { } NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); - return ok(Json.toJson(results2)); - + String j = JsonMappers.getMapper().writeValueAsString(results2); + rc.response() + .putHeader("content-type", "application/json") + .end(j); } catch (Throwable e) { log.error("Error in POST /knnnew",e); e.printStackTrace(); - return internalServerError(e.getMessage()); + rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) + .end("Error parsing request - " + e.getMessage()); + return; } }); - - return routingDsl.build(); } /** * Stop the server */ - public void stop() { - if (server != null) { - log.info("Attempting to stop server"); - server.stop(); - } + public void stop() throws Exception { + super.stop(); } public static void main(String[] args) throws Exception { - new NearestNeighborsServer().runMain(args); + runMain(args); } } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java index 9f8fd7241..b42c407e5 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest { public TemporaryFolder testDir = new TemporaryFolder(); @Test - //@Ignore("AB 2019/05/21 - Failing - Issue #7657") public void testNearestNeighbor() { double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; INDArray arr = Nd4j.create(data); @@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest { File writeToTmp = testDir.newFile(); writeToTmp.deleteOnExit(); BinarySerde.writeArrayToDisk(rand, writeToTmp); - NearestNeighborsServer server = new NearestNeighborsServer(); - server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", + NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", String.valueOf(localPort)); + Thread.sleep(3000); + NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); assertEquals(5, result.getResults().size()); - server.stop(); + NearestNeighborsServer.getInstance().stop(); } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml new file mode 100644 index 000000000..7953c2712 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml @@ -0,0 +1,42 @@ + + + + + + logs/application.log + + %date - [%level] - from %logger in %thread + %n%message%n%xException%n + + + + + + %logger{15} - %message%n%xException{5} + + + + + + + + + + + + + \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml index d820dd6b7..720e7a5f3 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml @@ -1,5 +1,6 @@