[WIP] DL4J nearestneighbors-sever: Play ->Vertx (#79)
* Switch Nearest neighbors server implementation from Play to Vertx Signed-off-by: AlexDBlack <blacka101@gmail.com> * No more scala version suffix for nearest neighbor server Signed-off-by: AlexDBlack <blacka101@gmail.com> * logback.xml fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Header tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
7f90930e7a
commit
0e3fcdc24d
|
@ -1,5 +1,6 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~ Copyright (c) 2019 Konduit K.K.
|
||||||
~
|
~
|
||||||
~ This program and the accompanying materials are made available under the
|
~ This program and the accompanying materials are made available under the
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -23,16 +24,11 @@
|
||||||
</parent>
|
</parent>
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
<artifactId>deeplearning4j-nearestneighbor-server_2.11</artifactId>
|
<artifactId>deeplearning4j-nearestneighbor-server</artifactId>
|
||||||
<packaging>jar</packaging>
|
<packaging>jar</packaging>
|
||||||
|
|
||||||
<name>deeplearning4j-nearestneighbor-server</name>
|
<name>deeplearning4j-nearestneighbor-server</name>
|
||||||
|
|
||||||
<properties>
|
|
||||||
<!-- Default scala versions, may be overwritten by build profiles -->
|
|
||||||
<scala.version>2.11.12</scala.version>
|
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
|
||||||
</properties>
|
|
||||||
<build>
|
<build>
|
||||||
<pluginManagement>
|
<pluginManagement>
|
||||||
<plugins>
|
<plugins>
|
||||||
|
@ -73,29 +69,17 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.typesafe.play</groupId>
|
<groupId>io.vertx</groupId>
|
||||||
<artifactId>play-java_2.11</artifactId>
|
<artifactId>vertx-core</artifactId>
|
||||||
<version>${playframework.version}</version>
|
<version>${vertx.version}</version>
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>com.google.code.findbugs</groupId>
|
|
||||||
<artifactId>jsr305</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.tomcat</groupId>
|
|
||||||
<artifactId>tomcat-servlet-api</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>net.jodah</groupId>
|
|
||||||
<artifactId>typetools</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>net.jodah</groupId>
|
<groupId>io.vertx</groupId>
|
||||||
<artifactId>typetools</artifactId>
|
<artifactId>vertx-web</artifactId>
|
||||||
<version>${jodah.typetools.version}</version>
|
<version>${vertx.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.mashape.unirest</groupId>
|
<groupId>com.mashape.unirest</groupId>
|
||||||
<artifactId>unirest-java</artifactId>
|
<artifactId>unirest-java</artifactId>
|
||||||
|
@ -108,25 +92,16 @@
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe.play</groupId>
|
|
||||||
<artifactId>play-json_2.11</artifactId>
|
|
||||||
<version>${playframework.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe.play</groupId>
|
|
||||||
<artifactId>play-server_2.11</artifactId>
|
|
||||||
<version>${playframework.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.beust</groupId>
|
<groupId>com.beust</groupId>
|
||||||
<artifactId>jcommander</artifactId>
|
<artifactId>jcommander</artifactId>
|
||||||
<version>${jcommander.version}</version>
|
<version>${jcommander.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.typesafe.play</groupId>
|
<groupId>ch.qos.logback</groupId>
|
||||||
<artifactId>play-netty-server_2.11</artifactId>
|
<artifactId>logback-classic</artifactId>
|
||||||
<version>${playframework.version}</version>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -19,6 +20,11 @@ package org.deeplearning4j.nearestneighbor.server;
|
||||||
import com.beust.jcommander.JCommander;
|
import com.beust.jcommander.JCommander;
|
||||||
import com.beust.jcommander.Parameter;
|
import com.beust.jcommander.Parameter;
|
||||||
import com.beust.jcommander.ParameterException;
|
import com.beust.jcommander.ParameterException;
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
|
import io.vertx.core.AbstractVerticle;
|
||||||
|
import io.vertx.core.Vertx;
|
||||||
|
import io.vertx.ext.web.Router;
|
||||||
|
import io.vertx.ext.web.handler.BodyHandler;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.deeplearning4j.clustering.sptree.DataPoint;
|
import org.deeplearning4j.clustering.sptree.DataPoint;
|
||||||
|
@ -26,6 +32,7 @@ import org.deeplearning4j.clustering.vptree.VPTree;
|
||||||
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
|
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
|
||||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nearestneighbor.model.*;
|
import org.deeplearning4j.nearestneighbor.model.*;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
|
@ -33,19 +40,10 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.serde.base64.Nd4jBase64;
|
import org.nd4j.serde.base64.Nd4jBase64;
|
||||||
import org.nd4j.serde.binary.BinarySerde;
|
import org.nd4j.serde.binary.BinarySerde;
|
||||||
import play.BuiltInComponents;
|
|
||||||
import play.Mode;
|
|
||||||
import play.libs.Json;
|
|
||||||
import play.routing.Router;
|
|
||||||
import play.routing.RoutingDsl;
|
|
||||||
import play.server.Server;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static play.mvc.Controller.request;
|
|
||||||
import static play.mvc.Results.*;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A rest server for using an
|
* A rest server for using an
|
||||||
* {@link VPTree} based on loading an ndarray containing
|
* {@link VPTree} based on loading an ndarray containing
|
||||||
|
@ -57,22 +55,33 @@ import static play.mvc.Results.*;
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NearestNeighborsServer {
|
public class NearestNeighborsServer extends AbstractVerticle {
|
||||||
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
|
|
||||||
private String ndarrayPath = null;
|
|
||||||
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
|
|
||||||
private String labelsPath = null;
|
|
||||||
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
|
|
||||||
private int port = 9000;
|
|
||||||
@Parameter(names = {"--similarityFunction"}, arity = 1)
|
|
||||||
private String similarityFunction = "euclidean";
|
|
||||||
@Parameter(names = {"--invert"}, arity = 1)
|
|
||||||
private boolean invert = false;
|
|
||||||
|
|
||||||
private Server server;
|
private static class RunArgs {
|
||||||
|
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
|
||||||
|
private String ndarrayPath = null;
|
||||||
|
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
|
||||||
|
private String labelsPath = null;
|
||||||
|
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
|
||||||
|
private int port = 9000;
|
||||||
|
@Parameter(names = {"--similarityFunction"}, arity = 1)
|
||||||
|
private String similarityFunction = "euclidean";
|
||||||
|
@Parameter(names = {"--invert"}, arity = 1)
|
||||||
|
private boolean invert = false;
|
||||||
|
}
|
||||||
|
|
||||||
public void runMain(String... args) throws Exception {
|
private static RunArgs instanceArgs;
|
||||||
JCommander jcmdr = new JCommander(this);
|
private static NearestNeighborsServer instance;
|
||||||
|
|
||||||
|
public NearestNeighborsServer(){ }
|
||||||
|
|
||||||
|
public static NearestNeighborsServer getInstance(){
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void runMain(String... args) {
|
||||||
|
RunArgs r = new RunArgs();
|
||||||
|
JCommander jcmdr = new JCommander(r);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
jcmdr.parse(args);
|
jcmdr.parse(args);
|
||||||
|
@ -84,7 +93,7 @@ public class NearestNeighborsServer {
|
||||||
|
|
||||||
//User provides invalid input -> print the usage info
|
//User provides invalid input -> print the usage info
|
||||||
jcmdr.usage();
|
jcmdr.usage();
|
||||||
if (ndarrayPath == null)
|
if (r.ndarrayPath == null)
|
||||||
log.error("Json path parameter is missing (null)");
|
log.error("Json path parameter is missing (null)");
|
||||||
try {
|
try {
|
||||||
Thread.sleep(500);
|
Thread.sleep(500);
|
||||||
|
@ -93,16 +102,20 @@ public class NearestNeighborsServer {
|
||||||
System.exit(1);
|
System.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
instanceArgs = r;
|
||||||
try {
|
try {
|
||||||
runHelper();
|
Vertx vertx = Vertx.vertx();
|
||||||
|
vertx.deployVerticle(NearestNeighborsServer.class.getName());
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
log.error("Error in NearestNeighboursServer run method",t);
|
log.error("Error in NearestNeighboursServer run method",t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void runHelper() throws Exception {
|
@Override
|
||||||
|
public void start() throws Exception {
|
||||||
|
instance = this;
|
||||||
|
|
||||||
String[] pathArr = ndarrayPath.split(",");
|
String[] pathArr = instanceArgs.ndarrayPath.split(",");
|
||||||
//INDArray[] pointsArr = new INDArray[pathArr.length];
|
//INDArray[] pointsArr = new INDArray[pathArr.length];
|
||||||
// first of all we reading shapes of saved eariler files
|
// first of all we reading shapes of saved eariler files
|
||||||
int rows = 0;
|
int rows = 0;
|
||||||
|
@ -111,7 +124,7 @@ public class NearestNeighborsServer {
|
||||||
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
|
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
|
||||||
|
|
||||||
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
|
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
|
||||||
Shape.size(shape, 1));
|
Shape.size(shape, 1));
|
||||||
|
|
||||||
if (Shape.rank(shape) != 2)
|
if (Shape.rank(shape) != 2)
|
||||||
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
|
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
|
||||||
|
@ -122,12 +135,12 @@ public class NearestNeighborsServer {
|
||||||
cols = Shape.size(shape, 1);
|
cols = Shape.size(shape, 1);
|
||||||
else if (cols != Shape.size(shape, 1))
|
else if (cols != Shape.size(shape, 1))
|
||||||
throw new DL4JInvalidInputException(
|
throw new DL4JInvalidInputException(
|
||||||
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
|
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<String> labels = new ArrayList<>();
|
final List<String> labels = new ArrayList<>();
|
||||||
if (labelsPath != null) {
|
if (instanceArgs.labelsPath != null) {
|
||||||
String[] labelsPathArr = labelsPath.split(",");
|
String[] labelsPathArr = instanceArgs.labelsPath.split(",");
|
||||||
for (int i = 0; i < labelsPathArr.length; i++) {
|
for (int i = 0; i < labelsPathArr.length; i++) {
|
||||||
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
|
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
|
||||||
}
|
}
|
||||||
|
@ -149,7 +162,7 @@ public class NearestNeighborsServer {
|
||||||
System.gc();
|
System.gc();
|
||||||
}
|
}
|
||||||
|
|
||||||
VPTree tree = new VPTree(points, similarityFunction, invert);
|
VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
|
||||||
|
|
||||||
//Set play secret key, if required
|
//Set play secret key, if required
|
||||||
//http://www.playframework.com/documentation/latest/ApplicationSecret
|
//http://www.playframework.com/documentation/latest/ApplicationSecret
|
||||||
|
@ -163,40 +176,57 @@ public class NearestNeighborsServer {
|
||||||
System.setProperty("play.crypto.secret", base64);
|
System.setProperty("play.crypto.secret", base64);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Router r = Router.router(vertx);
|
||||||
|
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
|
||||||
|
createRoutes(r, labels, tree, points);
|
||||||
|
|
||||||
server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
|
vertx.createHttpServer()
|
||||||
|
.requestHandler(r)
|
||||||
|
.listen(instanceArgs.port);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){
|
private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
|
||||||
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
|
|
||||||
//return the host information for a given id
|
r.post("/knn").handler(rc -> {
|
||||||
routingDsl.POST("/knn").routingTo(request -> {
|
|
||||||
try {
|
try {
|
||||||
NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
|
String json = rc.getBodyAsJson().encode();
|
||||||
|
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
|
||||||
|
|
||||||
NearestNeighbor nearestNeighbor =
|
NearestNeighbor nearestNeighbor =
|
||||||
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
|
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
|
||||||
|
|
||||||
if (record == null)
|
if (record == null) {
|
||||||
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
|
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
NearestNeighborsResults results =
|
NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
|
||||||
NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
|
|
||||||
|
|
||||||
|
|
||||||
return ok(Json.toJson(results));
|
|
||||||
|
|
||||||
|
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(JsonMappers.getMapper().writeValueAsString(results));
|
||||||
|
return;
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
log.error("Error in POST /knn",e);
|
log.error("Error in POST /knn",e);
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
return internalServerError(e.getMessage());
|
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
|
.end("Error parsing request - " + e.getMessage());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
routingDsl.POST("/knnnew").routingTo(request -> {
|
r.post("/knnnew").handler(rc -> {
|
||||||
try {
|
try {
|
||||||
Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
|
String json = rc.getBodyAsJson().encode();
|
||||||
if (record == null)
|
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
|
||||||
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
|
if (record == null) {
|
||||||
|
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
||||||
List<DataPoint> results;
|
List<DataPoint> results;
|
||||||
|
@ -214,9 +244,10 @@ public class NearestNeighborsServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (results.size() != distances.size()) {
|
if (results.size() != distances.size()) {
|
||||||
return internalServerError(
|
rc.response()
|
||||||
String.format("results.size == %d != %d == distances.size",
|
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
results.size(), distances.size()));
|
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
||||||
|
@ -228,30 +259,29 @@ public class NearestNeighborsServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
||||||
return ok(Json.toJson(results2));
|
String j = JsonMappers.getMapper().writeValueAsString(results2);
|
||||||
|
rc.response()
|
||||||
|
.putHeader("content-type", "application/json")
|
||||||
|
.end(j);
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
log.error("Error in POST /knnnew",e);
|
log.error("Error in POST /knnnew",e);
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
return internalServerError(e.getMessage());
|
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||||
|
.end("Error parsing request - " + e.getMessage());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return routingDsl.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop the server
|
* Stop the server
|
||||||
*/
|
*/
|
||||||
public void stop() {
|
public void stop() throws Exception {
|
||||||
if (server != null) {
|
super.stop();
|
||||||
log.info("Attempting to stop server");
|
|
||||||
server.stop();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
public static void main(String[] args) throws Exception {
|
||||||
new NearestNeighborsServer().runMain(args);
|
runMain(args);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -50,7 +51,6 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
//@Ignore("AB 2019/05/21 - Failing - Issue #7657")
|
|
||||||
public void testNearestNeighbor() {
|
public void testNearestNeighbor() {
|
||||||
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
|
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
|
||||||
INDArray arr = Nd4j.create(data);
|
INDArray arr = Nd4j.create(data);
|
||||||
|
@ -119,14 +119,15 @@ public class NearestNeighborTest extends BaseDL4JTest {
|
||||||
File writeToTmp = testDir.newFile();
|
File writeToTmp = testDir.newFile();
|
||||||
writeToTmp.deleteOnExit();
|
writeToTmp.deleteOnExit();
|
||||||
BinarySerde.writeArrayToDisk(rand, writeToTmp);
|
BinarySerde.writeArrayToDisk(rand, writeToTmp);
|
||||||
NearestNeighborsServer server = new NearestNeighborsServer();
|
NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
||||||
server.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
|
|
||||||
String.valueOf(localPort));
|
String.valueOf(localPort));
|
||||||
|
|
||||||
|
Thread.sleep(3000);
|
||||||
|
|
||||||
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
|
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
|
||||||
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
|
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
|
||||||
assertEquals(5, result.getResults().size());
|
assertEquals(5, result.getResults().size());
|
||||||
server.stop();
|
NearestNeighborsServer.getInstance().stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
~ Copyright (c) 2019 Konduit K.K.
|
||||||
|
~
|
||||||
|
~ This program and the accompanying materials are made available under the
|
||||||
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~
|
||||||
|
~ Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ License for the specific language governing permissions and limitations
|
||||||
|
~ under the License.
|
||||||
|
~
|
||||||
|
~ SPDX-License-Identifier: Apache-2.0
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||||
|
|
||||||
|
<configuration>
|
||||||
|
|
||||||
|
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||||
|
<file>logs/application.log</file>
|
||||||
|
<encoder>
|
||||||
|
<pattern>%date - [%level] - from %logger in %thread
|
||||||
|
%n%message%n%xException%n</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||||
|
<encoder>
|
||||||
|
<pattern> %logger{15} - %message%n%xException{5}
|
||||||
|
</pattern>
|
||||||
|
</encoder>
|
||||||
|
</appender>
|
||||||
|
|
||||||
|
<logger name="org.deeplearning4j" level="INFO" />
|
||||||
|
<logger name="org.datavec" level="INFO" />
|
||||||
|
<logger name="org.nd4j" level="INFO" />
|
||||||
|
|
||||||
|
<root level="ERROR">
|
||||||
|
<appender-ref ref="STDOUT" />
|
||||||
|
<appender-ref ref="FILE" />
|
||||||
|
</root>
|
||||||
|
</configuration>
|
|
@ -1,5 +1,6 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
~ Copyright (c) 2019 Konduit K.K.
|
||||||
~
|
~
|
||||||
~ This program and the accompanying materials are made available under the
|
~ This program and the accompanying materials are made available under the
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
~ terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
Loading…
Reference in New Issue