[WIP] DL4J nearestneighbors-sever: Play ->Vertx (#79)

* Switch Nearest neighbors server implementation from Play to Vertx

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* No more scala version suffix for nearest neighbor server

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* logback.xml fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Header tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-25 18:46:34 +11:00 committed by GitHub
parent 7f90930e7a
commit 0e3fcdc24d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 158 additions and 109 deletions

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at
@ -23,16 +24,11 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId> <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
<packaging>jar</packaging> <packaging>jar</packaging>
<name>deeplearning4j-nearestneighbor-server</name> <name>deeplearning4j-nearestneighbor-server</name>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<build> <build>
<pluginManagement> <pluginManagement>
<plugins> <plugins>
@ -73,29 +69,17 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>io.vertx</groupId>
<artifactId>play-java_2.11</artifactId> <artifactId>vertx-core</artifactId>
<version>${playframework.version}</version> <version>${vertx.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat</groupId>
<artifactId>tomcat-servlet-api</artifactId>
</exclusion>
<exclusion>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
</exclusion>
</exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>net.jodah</groupId> <groupId>io.vertx</groupId>
<artifactId>typetools</artifactId> <artifactId>vertx-web</artifactId>
<version>${jodah.typetools.version}</version> <version>${vertx.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mashape.unirest</groupId> <groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId> <artifactId>unirest-java</artifactId>
@ -108,25 +92,16 @@
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.beust</groupId> <groupId>com.beust</groupId>
<artifactId>jcommander</artifactId> <artifactId>jcommander</artifactId>
<version>${jcommander.version}</version> <version>${jcommander.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>play-netty-server_2.11</artifactId> <artifactId>logback-classic</artifactId>
<version>${playframework.version}</version> <scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException; import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint; import org.deeplearning4j.clustering.sptree.DataPoint;
@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*; import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde; import org.nd4j.serde.binary.BinarySerde;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Json;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File; import java.io.File;
import java.util.*; import java.util.*;
import static play.mvc.Controller.request;
import static play.mvc.Results.*;
/** /**
* A rest server for using an * A rest server for using an
* {@link VPTree} based on loading an ndarray containing * {@link VPTree} based on loading an ndarray containing
@ -57,22 +55,33 @@ import static play.mvc.Results.*;
* @author Adam Gibson * @author Adam Gibson
*/ */
@Slf4j @Slf4j
public class NearestNeighborsServer { public class NearestNeighborsServer extends AbstractVerticle {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
private String labelsPath = null;
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
private int port = 9000;
@Parameter(names = {"--similarityFunction"}, arity = 1)
private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false;
private Server server; private static class RunArgs {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
private String labelsPath = null;
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
private int port = 9000;
@Parameter(names = {"--similarityFunction"}, arity = 1)
private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false;
}
public void runMain(String... args) throws Exception { private static RunArgs instanceArgs;
JCommander jcmdr = new JCommander(this); private static NearestNeighborsServer instance;
public NearestNeighborsServer(){ }
public static NearestNeighborsServer getInstance(){
return instance;
}
public static void runMain(String... args) {
RunArgs r = new RunArgs();
JCommander jcmdr = new JCommander(r);
try { try {
jcmdr.parse(args); jcmdr.parse(args);
@ -84,7 +93,7 @@ public class NearestNeighborsServer {
//User provides invalid input -> print the usage info //User provides invalid input -> print the usage info
jcmdr.usage(); jcmdr.usage();
if (ndarrayPath == null) if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)"); log.error("Json path parameter is missing (null)");
try { try {
Thread.sleep(500); Thread.sleep(500);
@ -93,16 +102,20 @@ public class NearestNeighborsServer {
System.exit(1); System.exit(1);
} }
instanceArgs = r;
try { try {
runHelper(); Vertx vertx = Vertx.vertx();
vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){ } catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t); log.error("Error in NearestNeighboursServer run method",t);
} }
} }
protected void runHelper() throws Exception { @Override
public void start() throws Exception {
instance = this;
String[] pathArr = ndarrayPath.split(","); String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length]; //INDArray[] pointsArr = new INDArray[pathArr.length];
// first of all we reading shapes of saved eariler files // first of all we reading shapes of saved eariler files
int rows = 0; int rows = 0;
@ -111,7 +124,7 @@ public class NearestNeighborsServer {
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
Shape.size(shape, 1)); Shape.size(shape, 1));
if (Shape.rank(shape) != 2) if (Shape.rank(shape) != 2)
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
@ -122,12 +135,12 @@ public class NearestNeighborsServer {
cols = Shape.size(shape, 1); cols = Shape.size(shape, 1);
else if (cols != Shape.size(shape, 1)) else if (cols != Shape.size(shape, 1))
throw new DL4JInvalidInputException( throw new DL4JInvalidInputException(
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
} }
final List<String> labels = new ArrayList<>(); final List<String> labels = new ArrayList<>();
if (labelsPath != null) { if (instanceArgs.labelsPath != null) {
String[] labelsPathArr = labelsPath.split(","); String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) { for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
} }
@ -149,7 +162,7 @@ public class NearestNeighborsServer {
System.gc(); System.gc();
} }
VPTree tree = new VPTree(points, similarityFunction, invert); VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required //Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret //http://www.playframework.com/documentation/latest/ApplicationSecret
@ -163,40 +176,57 @@ public class NearestNeighborsServer {
System.setProperty("play.crypto.secret", base64); System.setProperty("play.crypto.secret", base64);
} }
Router r = Router.router(vertx);
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
createRoutes(r, labels, tree, points);
server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); vertx.createHttpServer()
.requestHandler(r)
.listen(instanceArgs.port);
} }
protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){ private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
//return the host information for a given id r.post("/knn").handler(rc -> {
routingDsl.POST("/knn").routingTo(request -> {
try { try {
NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); String json = rc.getBodyAsJson().encode();
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
NearestNeighbor nearestNeighbor = NearestNeighbor nearestNeighbor =
NearestNeighbor.builder().points(points).record(record).tree(tree).build(); NearestNeighbor.builder().points(points).record(record).tree(tree).build();
if (record == null) if (record == null) {
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
NearestNeighborsResults results = NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
return ok(Json.toJson(results));
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(results));
return;
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knn",e); log.error("Error in POST /knn",e);
e.printStackTrace(); e.printStackTrace();
return internalServerError(e.getMessage()); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
} }
}); });
routingDsl.POST("/knnnew").routingTo(request -> { r.post("/knnnew").handler(rc -> {
try { try {
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); String json = rc.getBodyAsJson().encode();
if (record == null) Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); if (record == null) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List<DataPoint> results; List<DataPoint> results;
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
} }
if (results.size() != distances.size()) { if (results.size() != distances.size()) {
return internalServerError( rc.response()
String.format("results.size == %d != %d == distances.size", .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
results.size(), distances.size())); .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
return;
} }
List<NearestNeighborsResult> nnResult = new ArrayList<>(); List<NearestNeighborsResult> nnResult = new ArrayList<>();
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
} }
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
return ok(Json.toJson(results2)); String j = JsonMappers.getMapper().writeValueAsString(results2);
rc.response()
.putHeader("content-type", "application/json")
.end(j);
} catch (Throwable e) { } catch (Throwable e) {
log.error("Error in POST /knnnew",e); log.error("Error in POST /knnnew",e);
e.printStackTrace(); e.printStackTrace();
return internalServerError(e.getMessage()); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
} }
}); });
return routingDsl.build();
} }
/** /**
* Stop the server * Stop the server
*/ */
public void stop() { public void stop() throws Exception {
if (server != null) { super.stop();
log.info("Attempting to stop server");
server.stop();
}
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
new NearestNeighborsServer().runMain(args); runMain(args);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@Test @Test
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testNearestNeighbor() { public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data); INDArray arr = Nd4j.create(data);
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
File writeToTmp = testDir.newFile(); File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit(); writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp); BinarySerde.writeArrayToDisk(rand, writeToTmp);
NearestNeighborsServer server = new NearestNeighborsServer(); NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort)); String.valueOf(localPort));
Thread.sleep(3000);
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size()); assertEquals(5, result.getResults().size());
server.stop(); NearestNeighborsServer.getInstance().stop();
} }

View File

@ -0,0 +1,42 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2019 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc. ~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~ ~
~ This program and the accompanying materials are made available under the ~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at ~ terms of the Apache License, Version 2.0 which is available at