[WIP] DL4J nearestneighbors-sever: Play ->Vertx (#79)

* Switch Nearest neighbors server implementation from Play to Vertx

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* No more scala version suffix for nearest neighbor server

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* logback.xml fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Header tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-25 18:46:34 +11:00 committed by GitHub
parent 7f90930e7a
commit 0e3fcdc24d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 158 additions and 109 deletions

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
@ -23,16 +24,11 @@
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId>
<artifactId>deeplearning4j-nearestneighbor-server</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-nearestneighbor-server</name>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<build>
<pluginManagement>
<plugins>
@ -73,29 +69,17 @@
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
<version>${playframework.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.tomcat</groupId>
<artifactId>tomcat-servlet-api</artifactId>
</exclusion>
<exclusion>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
</exclusion>
</exclusions>
<groupId>io.vertx</groupId>
<artifactId>vertx-core</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
<version>${jodah.typetools.version}</version>
<groupId>io.vertx</groupId>
<artifactId>vertx-web</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
@ -108,25 +92,16 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>${jcommander.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${playframework.version}</version>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Json;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File;
import java.util.*;
import static play.mvc.Controller.request;
import static play.mvc.Results.*;
/**
* A rest server for using an
* {@link VPTree} based on loading an ndarray containing
@ -57,22 +55,33 @@ import static play.mvc.Results.*;
* @author Adam Gibson
*/
@Slf4j
public class NearestNeighborsServer {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
private String labelsPath = null;
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
private int port = 9000;
@Parameter(names = {"--similarityFunction"}, arity = 1)
private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false;
public class NearestNeighborsServer extends AbstractVerticle {
private Server server;
private static class RunArgs {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
private String labelsPath = null;
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
private int port = 9000;
@Parameter(names = {"--similarityFunction"}, arity = 1)
private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false;
}
public void runMain(String... args) throws Exception {
JCommander jcmdr = new JCommander(this);
private static RunArgs instanceArgs;
private static NearestNeighborsServer instance;
public NearestNeighborsServer(){ }
public static NearestNeighborsServer getInstance(){
return instance;
}
public static void runMain(String... args) {
RunArgs r = new RunArgs();
JCommander jcmdr = new JCommander(r);
try {
jcmdr.parse(args);
@ -84,7 +93,7 @@ public class NearestNeighborsServer {
//User provides invalid input -> print the usage info
jcmdr.usage();
if (ndarrayPath == null)
if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)");
try {
Thread.sleep(500);
@ -93,16 +102,20 @@ public class NearestNeighborsServer {
System.exit(1);
}
instanceArgs = r;
try {
runHelper();
Vertx vertx = Vertx.vertx();
vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t);
}
}
protected void runHelper() throws Exception {
@Override
public void start() throws Exception {
instance = this;
String[] pathArr = ndarrayPath.split(",");
String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length];
// first of all we reading shapes of saved eariler files
int rows = 0;
@ -111,7 +124,7 @@ public class NearestNeighborsServer {
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
Shape.size(shape, 1));
Shape.size(shape, 1));
if (Shape.rank(shape) != 2)
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
@ -122,12 +135,12 @@ public class NearestNeighborsServer {
cols = Shape.size(shape, 1);
else if (cols != Shape.size(shape, 1))
throw new DL4JInvalidInputException(
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
}
final List<String> labels = new ArrayList<>();
if (labelsPath != null) {
String[] labelsPathArr = labelsPath.split(",");
if (instanceArgs.labelsPath != null) {
String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
}
@ -149,7 +162,7 @@ public class NearestNeighborsServer {
System.gc();
}
VPTree tree = new VPTree(points, similarityFunction, invert);
VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
@ -163,40 +176,57 @@ public class NearestNeighborsServer {
System.setProperty("play.crypto.secret", base64);
}
Router r = Router.router(vertx);
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
createRoutes(r, labels, tree, points);
server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
vertx.createHttpServer()
.requestHandler(r)
.listen(instanceArgs.port);
}
protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
//return the host information for a given id
routingDsl.POST("/knn").routingTo(request -> {
private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
r.post("/knn").handler(rc -> {
try {
NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
String json = rc.getBodyAsJson().encode();
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
NearestNeighbor nearestNeighbor =
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
if (record == null)
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
if (record == null) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
NearestNeighborsResults results =
NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
return ok(Json.toJson(results));
NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(results));
return;
} catch (Throwable e) {
log.error("Error in POST /knn",e);
e.printStackTrace();
return internalServerError(e.getMessage());
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
}
});
routingDsl.POST("/knnnew").routingTo(request -> {
r.post("/knnnew").handler(rc -> {
try {
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
if (record == null)
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
String json = rc.getBodyAsJson().encode();
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
if (record == null) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List<DataPoint> results;
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
}
if (results.size() != distances.size()) {
return internalServerError(
String.format("results.size == %d != %d == distances.size",
results.size(), distances.size()));
rc.response()
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
return;
}
List<NearestNeighborsResult> nnResult = new ArrayList<>();
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
}
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
return ok(Json.toJson(results2));
String j = JsonMappers.getMapper().writeValueAsString(results2);
rc.response()
.putHeader("content-type", "application/json")
.end(j);
} catch (Throwable e) {
log.error("Error in POST /knnnew",e);
e.printStackTrace();
return internalServerError(e.getMessage());
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
}
});
return routingDsl.build();
}
/**
* Stop the server
*/
public void stop() {
if (server != null) {
log.info("Attempting to stop server");
server.stop();
}
public void stop() throws Exception {
super.stop();
}
public static void main(String[] args) throws Exception {
new NearestNeighborsServer().runMain(args);
runMain(args);
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
public TemporaryFolder testDir = new TemporaryFolder();
@Test
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data);
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp);
NearestNeighborsServer server = new NearestNeighborsServer();
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort));
Thread.sleep(3000);
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size());
server.stop();
NearestNeighborsServer.getInstance().stop();
}

View File

@ -0,0 +1,42 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2019 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -1,5 +1,6 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~ Copyright (c) 2019 Konduit K.K.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at