[WIP] DL4J nearestneighbors-sever: Play ->Vertx (#79)
* Switch Nearest neighbors server implementation from Play to Vertx Signed-off-by: AlexDBlack <blacka101@gmail.com> * No more scala version suffix for nearest neighbor server Signed-off-by: AlexDBlack <blacka101@gmail.com> * logback.xml fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Header tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
7f90930e7a
commit
0e3fcdc24d
|
@ -1,5 +1,6 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~ Copyright (c) 2019 Konduit K.K.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -23,16 +24,11 @@
|
|||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId>
|
||||
<artifactId>deeplearning4j-nearestneighbor-server</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>deeplearning4j-nearestneighbor-server</name>
|
||||
|
||||
<properties>
|
||||
<!-- Default scala versions, may be overwritten by build profiles -->
|
||||
<scala.version>2.11.12</scala.version>
|
||||
<scala.binary.version>2.11</scala.binary.version>
|
||||
</properties>
|
||||
<build>
|
||||
<pluginManagement>
|
||||
<plugins>
|
||||
|
@ -73,29 +69,17 @@
|
|||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-java_2.11</artifactId>
|
||||
<version>${playframework.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>com.google.code.findbugs</groupId>
|
||||
<artifactId>jsr305</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.apache.tomcat</groupId>
|
||||
<artifactId>tomcat-servlet-api</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>net.jodah</groupId>
|
||||
<artifactId>typetools</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
<groupId>io.vertx</groupId>
|
||||
<artifactId>vertx-core</artifactId>
|
||||
<version>${vertx.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>net.jodah</groupId>
|
||||
<artifactId>typetools</artifactId>
|
||||
<version>${jodah.typetools.version}</version>
|
||||
<groupId>io.vertx</groupId>
|
||||
<artifactId>vertx-web</artifactId>
|
||||
<version>${vertx.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
|
@ -108,25 +92,16 @@
|
|||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-json_2.11</artifactId>
|
||||
<version>${playframework.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-server_2.11</artifactId>
|
||||
<version>${playframework.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.beust</groupId>
|
||||
<artifactId>jcommander</artifactId>
|
||||
<version>${jcommander.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-netty-server_2.11</artifactId>
|
||||
<version>${playframework.version}</version>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
|
|||
import com.beust.jcommander.JCommander;
|
||||
import com.beust.jcommander.Parameter;
|
||||
import com.beust.jcommander.ParameterException;
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import io.vertx.core.AbstractVerticle;
|
||||
import io.vertx.core.Vertx;
|
||||
import io.vertx.ext.web.Router;
|
||||
import io.vertx.ext.web.handler.BodyHandler;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.deeplearning4j.clustering.sptree.DataPoint;
|
||||
|
@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
|
|||
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nearestneighbor.model.*;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
|
@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.serde.base64.Nd4jBase64;
|
||||
import org.nd4j.serde.binary.BinarySerde;
|
||||
import play.BuiltInComponents;
|
||||
import play.Mode;
|
||||
import play.libs.Json;
|
||||
import play.routing.Router;
|
||||
import play.routing.RoutingDsl;
|
||||
import play.server.Server;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.*;
|
||||
|
||||
import static play.mvc.Controller.request;
|
||||
import static play.mvc.Results.*;
|
||||
|
||||
/**
|
||||
* A rest server for using an
|
||||
* {@link VPTree} based on loading an ndarray containing
|
||||
|
@ -57,22 +55,33 @@ import static play.mvc.Results.*;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
public class NearestNeighborsServer {
|
||||
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
|
||||
private String ndarrayPath = null;
|
||||
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
|
||||
private String labelsPath = null;
|
||||
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
|
||||
private int port = 9000;
|
||||
@Parameter(names = {"--similarityFunction"}, arity = 1)
|
||||
private String similarityFunction = "euclidean";
|
||||
@Parameter(names = {"--invert"}, arity = 1)
|
||||
private boolean invert = false;
|
||||
public class NearestNeighborsServer extends AbstractVerticle {
|
||||
|
||||
private Server server;
|
||||
private static class RunArgs {
|
||||
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
|
||||
private String ndarrayPath = null;
|
||||
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
|
||||
private String labelsPath = null;
|
||||
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
|
||||
private int port = 9000;
|
||||
@Parameter(names = {"--similarityFunction"}, arity = 1)
|
||||
private String similarityFunction = "euclidean";
|
||||
@Parameter(names = {"--invert"}, arity = 1)
|
||||
private boolean invert = false;
|
||||
}
|
||||
|
||||
public void runMain(String... args) throws Exception {
|
||||
JCommander jcmdr = new JCommander(this);
|
||||
private static RunArgs instanceArgs;
|
||||
private static NearestNeighborsServer instance;
|
||||
|
||||
public NearestNeighborsServer(){ }
|
||||
|
||||
public static NearestNeighborsServer getInstance(){
|
||||
return instance;
|
||||
}
|
||||
|
||||
public static void runMain(String... args) {
|
||||
RunArgs r = new RunArgs();
|
||||
JCommander jcmdr = new JCommander(r);
|
||||
|
||||
try {
|
||||
jcmdr.parse(args);
|
||||
|
@ -84,7 +93,7 @@ public class NearestNeighborsServer {
|
|||
|
||||
//User provides invalid input -> print the usage info
|
||||
jcmdr.usage();
|
||||
if (ndarrayPath == null)
|
||||
if (r.ndarrayPath == null)
|
||||
log.error("Json path parameter is missing (null)");
|
||||
try {
|
||||
Thread.sleep(500);
|
||||
|
@ -93,16 +102,20 @@ public class NearestNeighborsServer {
|
|||
System.exit(1);
|
||||
}
|
||||
|
||||
instanceArgs = r;
|
||||
try {
|
||||
runHelper();
|
||||
Vertx vertx = Vertx.vertx();
|
||||
vertx.deployVerticle(NearestNeighborsServer.class.getName());
|
||||
} catch (Throwable t){
|
||||
log.error("Error in NearestNeighboursServer run method",t);
|
||||
}
|
||||
}
|
||||
|
||||
protected void runHelper() throws Exception {
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
instance = this;
|
||||
|
||||
String[] pathArr = ndarrayPath.split(",");
|
||||
String[] pathArr = instanceArgs.ndarrayPath.split(",");
|
||||
//INDArray[] pointsArr = new INDArray[pathArr.length];
|
||||
// first of all we reading shapes of saved eariler files
|
||||
int rows = 0;
|
||||
|
@ -111,7 +124,7 @@ public class NearestNeighborsServer {
|
|||
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
|
||||
|
||||
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
|
||||
Shape.size(shape, 1));
|
||||
Shape.size(shape, 1));
|
||||
|
||||
if (Shape.rank(shape) != 2)
|
||||
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
|
||||
|
@ -122,12 +135,12 @@ public class NearestNeighborsServer {
|
|||
cols = Shape.size(shape, 1);
|
||||
else if (cols != Shape.size(shape, 1))
|
||||
throw new DL4JInvalidInputException(
|
||||
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
|
||||
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
|
||||
}
|
||||
|
||||
final List<String> labels = new ArrayList<>();
|
||||
if (labelsPath != null) {
|
||||
String[] labelsPathArr = labelsPath.split(",");
|
||||
if (instanceArgs.labelsPath != null) {
|
||||
String[] labelsPathArr = instanceArgs.labelsPath.split(",");
|
||||
for (int i = 0; i < labelsPathArr.length; i++) {
|
||||
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
|
||||
}
|
||||
|
@ -149,7 +162,7 @@ public class NearestNeighborsServer {
|
|||
System.gc();
|
||||
}
|
||||
|
||||
VPTree tree = new VPTree(points, similarityFunction, invert);
|
||||
VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
|
||||
|
||||
//Set play secret key, if required
|
||||
//http://www.playframework.com/documentation/latest/ApplicationSecret
|
||||
|
@ -163,40 +176,57 @@ public class NearestNeighborsServer {
|
|||
System.setProperty("play.crypto.secret", base64);
|
||||
}
|
||||
|
||||
Router r = Router.router(vertx);
|
||||
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
|
||||
createRoutes(r, labels, tree, points);
|
||||
|
||||
server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
|
||||
vertx.createHttpServer()
|
||||
.requestHandler(r)
|
||||
.listen(instanceArgs.port);
|
||||
}
|
||||
|
||||
protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){
|
||||
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
|
||||
//return the host information for a given id
|
||||
routingDsl.POST("/knn").routingTo(request -> {
|
||||
private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
|
||||
|
||||
r.post("/knn").handler(rc -> {
|
||||
try {
|
||||
NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
|
||||
String json = rc.getBodyAsJson().encode();
|
||||
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
|
||||
|
||||
NearestNeighbor nearestNeighbor =
|
||||
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
|
||||
|
||||
if (record == null)
|
||||
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
|
||||
if (record == null) {
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||
return;
|
||||
}
|
||||
|
||||
NearestNeighborsResults results =
|
||||
NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
|
||||
|
||||
|
||||
return ok(Json.toJson(results));
|
||||
NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
|
||||
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(results));
|
||||
return;
|
||||
} catch (Throwable e) {
|
||||
log.error("Error in POST /knn",e);
|
||||
e.printStackTrace();
|
||||
return internalServerError(e.getMessage());
|
||||
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end("Error parsing request - " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
routingDsl.POST("/knnnew").routingTo(request -> {
|
||||
r.post("/knnnew").handler(rc -> {
|
||||
try {
|
||||
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
|
||||
if (record == null)
|
||||
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
|
||||
String json = rc.getBodyAsJson().encode();
|
||||
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
|
||||
if (record == null) {
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||
return;
|
||||
}
|
||||
|
||||
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
||||
List<DataPoint> results;
|
||||
|
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
|
|||
}
|
||||
|
||||
if (results.size() != distances.size()) {
|
||||
return internalServerError(
|
||||
String.format("results.size == %d != %d == distances.size",
|
||||
results.size(), distances.size()));
|
||||
rc.response()
|
||||
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
||||
|
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
|
|||
}
|
||||
|
||||
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
||||
return ok(Json.toJson(results2));
|
||||
|
||||
String j = JsonMappers.getMapper().writeValueAsString(results2);
|
||||
rc.response()
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(j);
|
||||
} catch (Throwable e) {
|
||||
log.error("Error in POST /knnnew",e);
|
||||
e.printStackTrace();
|
||||
return internalServerError(e.getMessage());
|
||||
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end("Error parsing request - " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
return routingDsl.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the server
|
||||
*/
|
||||
public void stop() {
|
||||
if (server != null) {
|
||||
log.info("Attempting to stop server");
|
||||
server.stop();
|
||||
}
|
||||
public void stop() throws Exception {
|
||||
super.stop();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
new NearestNeighborsServer().runMain(args);
|
||||
runMain(args);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
|||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Test
|
||||
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
||||
public void testNearestNeighbor() {
|
||||
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
|
||||
INDArray arr = Nd4j.create(data);
|
||||
|
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
|||
File writeToTmp = testDir.newFile();
|
||||
writeToTmp.deleteOnExit();
|
||||
BinarySerde.writeArrayToDisk(rand, writeToTmp);
|
||||
NearestNeighborsServer server = new NearestNeighborsServer();
|
||||
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
||||
NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
||||
String.valueOf(localPort));
|
||||
|
||||
Thread.sleep(3000);
|
||||
|
||||
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
|
||||
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
|
||||
assertEquals(5, result.getResults().size());
|
||||
server.stop();
|
||||
NearestNeighborsServer.getInstance().stop();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2019 Konduit K.K.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.deeplearning4j" level="INFO" />
|
||||
<logger name="org.datavec" level="INFO" />
|
||||
<logger name="org.nd4j" level="INFO" />
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
</configuration>
|
|
@ -1,5 +1,6 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~ Copyright (c) 2019 Konduit K.K.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
|
|
Loading…
Reference in New Issue