[WIP] Remote inference (#96)
* fix pad javadoc and @see links. (#72) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * [WIP] More fixes (#73) * special tests for ConstantTadHelper/ConstantShapeHelper Signed-off-by: raver119 <raver119@gmail.com> * release methods for data buffers Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com> * delete temporary TadPack C++/Java side (#74) Signed-off-by: raver119 <raver119@gmail.com> * Zoo model TF import test updates (#75) * argLine fix, update compression_gru comment * updated comment for xception * undid but commented argLine change * updated xlnet comment * copyright headers * - new NDArray methods like()/ulike() (#77) - fix for depthwise_conv2d_bp + special test Signed-off-by: raver119 <raver119@gmail.com> * upsampling2d fix CUDA Signed-off-by: raver119 <raver119@gmail.com> * DL4J trace logging (#79) * MLN/CG trace logging for debugging Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tiny tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * strided_slice_bp shape fn leak fix Signed-off-by: raver119 <raver119@gmail.com> * SameDiff fixes and naming (#78) * remove SDVariable inplace methods * import methods * npe fix in OpVal * removed SameDiff inplace ops from tests * Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything * quick fixes * javadoc * SDVariable eval with placeholders * use regex match * better matching * fix javadoc. (#76) * fix javadoc. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace most @see with @link s. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * 4 additional tests Signed-off-by: raver119 <raver119@gmail.com> * Various DL4J/ND4J fixes (#81) * #7954 Force refresh of UI when switching tabs on overview page Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8017 Concurrent modification exception (synchronize) fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8033 Don't initialize updater in middle of writing memory crash dump Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8208 Fix shape checks for ND4J int[] creator methods Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6385 #7992 Keras import naming fixes + cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8016 Upsampling3D - add NDHWC format support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Refactor NativeOps.h to export C functions * Actually export functions from NativeOps.h * Adapt the Java wrappers in ND4J generated with JavaCPP * Create C wrappers for some of the C++ classes currently used by ND4J * remove duplicate code in createBufferDetached. (#83) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Keras model import - updater lr fix (#84) * Keras model import - updater lr fix Signed-off-by: eraly <susan.eraly@gmail.com> * Keras model import - updater lr fix, cleanup Signed-off-by: eraly <susan.eraly@gmail.com> * Fix functions of OpaqueVariablesSet * SameDiff Convolution Config validation, better output methods (#82) * Conv Config validation & tests Signed-off-by: Ryan Nett <rnett@skymind.io> * stackOutputs utility method Signed-off-by: Ryan Nett <rnett@skymind.io> * use constructor for validation, support negative kernel sizes (infered from weights) Signed-off-by: Ryan Nett <rnett@skymind.io> * better output methods Signed-off-by: Ryan Nett <rnett@skymind.io> * move output to be with fit and evaluate Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * refactor duplicate code from pad methods. (#86) * refactor duplicate code from pad methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * replace switch with if. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Various ND4J/DL4J fixes and improvements (#87) * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Reshape and reallocate - small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6488 ElementWiseVertex broadcast support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Constructors and broadcast supported it Transforms.max/min Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8054 ElementWiseVertex now supports broadcast inputs Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8057 Nd4j.create overload dtype fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7551 ND4J Shape validation fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * [WIP] Numpy boolean import (#91) * numpy bool type Signed-off-by: raver119 <raver119@gmail.com> * numpy bool java side Signed-off-by: raver119 <raver119@gmail.com> * remove create method with unused parameter. (#89) * remove create method with unused parameter. * removed more unused methods. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * removing more unused code. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * last removal of unused code. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * remove createSparse methods. (#92) Signed-off-by: Robert Altena <Rob@Ra-ai.com> * Various ND4J/DL4J fixes (#90) * Deprecate Old*Op instances Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8063 #8054 Broadcast exceptions + cleanup inplace ops Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove bad test condition Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7993 Fix shape function issue in crop_and_resize op Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J SameDiff lambda layer fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8029 Fix for pnorm backprop math Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8038 Fix Op profiler NaN/Inf triggering + add tests (#93) Signed-off-by: AlexDBlack <blacka101@gmail.com> * createUninitializedDetached refactoring. (#94) * wip * update interface, add null implementations. * Breaking one test in a weird way. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * createUninitializedDetached refactored. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * cuda build fix for issues introduced by recent refactoring Signed-off-by: raver119 <raver119@gmail.com> * initial commit Signed-off-by: raver119 <raver119@gmail.com> * deps tweaks Signed-off-by: raver119 <raver119@gmail.com> * initial prototype Signed-off-by: raver119 <raver119@gmail.com> * modules reorganized Signed-off-by: raver119 <raver119@gmail.com> * gprc module moved to nd4j-remote as well Signed-off-by: raver119 <raver119@gmail.com> * gprc module moved to nd4j-remote as well Signed-off-by: raver119 <raver119@gmail.com> * serving prototype Signed-off-by: raver119 <raver119@gmail.com> * serving prototype Signed-off-by: raver119 <raver119@gmail.com> * serving prototype Signed-off-by: raver119 <raver119@gmail.com> * serving prototype Signed-off-by: raver119 <raver119@gmail.com> * [WIP] More of CUDA (#95) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * Implementation of hashcode cuda helper. Working edition. * Fixed parallel test input arangements. * Fixed tests for hashcode op. * Fixed shape calculation for image:crop_and_resize op and test. * NativeOps tests. Initial test suite. * Added tests for indexReduce methods. * Added test on execBroadcast with NDArray as dimensions. * Added test on execBroadcastBool with NDArray as dimensions. * Added tests on execPairwiseTransform and execPairwiseTransofrmBool. * Added tests for execReduce with scalar results. * Added reduce tests for non-empty dims array. * Added tests for reduce3. * Added tests for execScalar. * Added tests for execSummaryStats. * - provide cpu/cuda code for batch_to_space - testing it Signed-off-by: Yurii <yurii@skymind.io> * - remove old test for batch_to_space (had wrong format and numbers were not checked) Signed-off-by: Yurii <yurii@skymind.io> * Fixed complilation errors with test. * Added test for execTransformFloat. * Added test for execTransformSame. * Added test for execTransformBool. * Added test for execTransformStrict. * Added tests for execScalar/execScalarBool with TADs. * Added test for flatten. * - provide cpu/cuda code for space_to_Batch operaion Signed-off-by: Yurii <yurii@skymind.io> * Added test for concat. * comment unnecessary stuff in s_t_b Signed-off-by: Yurii <yurii@skymind.io> * Added test for specialConcat. * Added tests for memcpy/set routines. * Fixed pullRow cuda test. * Added pullRow test. * Added average test. * - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...) Signed-off-by: Yurii <yurii@skymind.io> * - debugging and fixing cuda tests in JavaInteropTests file Signed-off-by: Yurii <yurii@skymind.io> * - correct some tests Signed-off-by: Yurii <yurii@skymind.io> * Added test for shuffle. * Fixed ops declarations. * Restored omp and added shuffle test. * Added convertTypes test. * Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps. * Added sort tests. * Added tests for execCustomOp. * - further debuging and fixing tests terminated with crash Signed-off-by: Yurii <yurii@skymind.io> * Added tests for calculateOutputShapes. * Addded Benchmarks test. * Commented benchmark tests. * change assertion Signed-off-by: raver119 <raver119@gmail.com> * Added tests for apply_sgd op. Added cpu helper for that op. * Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps. * Added test for assign broadcastable. * Added tests for assign_bp op. * Added tests for axpy op. * - assign/execScalar/execTransformAny signature change - minor test fix Signed-off-by: raver119 <raver119@gmail.com> * Fixed axpy op. * meh Signed-off-by: raver119 <raver119@gmail.com> * - fix tests for nativeOps::concat Signed-off-by: Yurii <yurii@skymind.io> * sequential transform/scalar Signed-off-by: raver119 <raver119@gmail.com> * allow nested parallelism Signed-off-by: raver119 <raver119@gmail.com> * assign_bp leak fix Signed-off-by: raver119 <raver119@gmail.com> * block setRNG fix Signed-off-by: raver119 <raver119@gmail.com> * enable parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * enable nested parallelism by default Signed-off-by: raver119 <raver119@gmail.com> * Added cuda implementation for row_count helper. * Added implementation for tnse gains op helper. * - take into account possible situations when input arrays are empty in reduce_ cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces. * Added kernel for tsne/symmetrized op heleper. * Implementation of tsne/symmetrized op cuda helper. Working edition. * Eliminated waste printfs. * Added test for broadcastgradientargs op. * host-only fallback for empty reduce float Signed-off-by: raver119 <raver119@gmail.com> * - some tests fixes Signed-off-by: Yurii <yurii@skymind.io> * - correct the rest of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * - further correction of reduce_ stuff Signed-off-by: Yurii <yurii@skymind.io> * Added test for Cbow op. Also added cuda implementation for cbow helpers. * - improve code of stack operation for scalar case Signed-off-by: Yurii <yurii@skymind.io> * - provide cuda kernel for gatherND operation Signed-off-by: Yurii <yurii@skymind.io> * Implementation of cbow helpers with cuda kernels. * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * - further correction of cuda stuff Signed-off-by: Yurii <yurii@skymind.io> * Implementatation of cbow op helper with cuda kernels. Working edition. * Skip random testing for cudablas case. * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for ELU and ELU_BP ops. * Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops. * Added tests for neq_scalar. * Added test for noop. * - further work on clipbynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * - get rid of concat op call, use instead direct concat helper call Signed-off-by: Yurii <yurii@skymind.io> * lstmBlockCell context fix Signed-off-by: raver119 <raver119@gmail.com> * Added tests for lrelu and lrelu_bp. * Added tests for selu and selu_bp. * Fixed lrelu derivative helpers. * - some corrections in lstm Signed-off-by: Yurii <yurii@skymind.io> * operator * result shape fix Signed-off-by: raver119 <raver119@gmail.com> * - correct typo in lstmCell Signed-off-by: Yurii <yurii@skymind.io> * few tests fixed Signed-off-by: raver119 <raver119@gmail.com> * CUDA inverse broadcast bool fix Signed-off-by: raver119 <raver119@gmail.com> * disable MMAP test for CUDA Signed-off-by: raver119 <raver119@gmail.com> * BooleanOp syncToDevice Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * additional data types for im2col/col2im Signed-off-by: raver119 <raver119@gmail.com> * Added test for firas_sparse op. * one more RandomBuffer test excluded Signed-off-by: raver119 <raver119@gmail.com> * Added tests for flatten op. * Added test for Floor op. * bunch of tests fixed Signed-off-by: raver119 <raver119@gmail.com> * mmulDot tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Implemented floordiv_bp op and tests. * Fixed scalar case with cuda implementation for bds. * - work on cuda kernel for clip_by_norm backprop op is completed Signed-off-by: Yurii <yurii@skymind.io> * Eliminate cbow crach. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Eliminated abortion with batched nlp test. * more tests fixed Signed-off-by: raver119 <raver119@gmail.com> * Fixed shared flag initializing. * disabled bunch of cpu workspaces tests Signed-off-by: raver119 <raver119@gmail.com> * scalar operators fix: missing registerSpecialUse call Signed-off-by: raver119 <raver119@gmail.com> * Fixed logdet for cuda and tests. * - correct clipBynorm_bp Signed-off-by: Yurii <yurii@skymind.io> * Fixed crop_and_resize shape datatype. * - correct some mmul tests Signed-off-by: Yurii <yurii@skymind.io> * build fix Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI Signed-off-by: raver119 <raver119@gmail.com> * exclude two methods for JNI (#97) Signed-off-by: raver119 <raver119@gmail.com> * temporary stack fix Signed-off-by: raver119 <raver119@gmail.com> * downgrade jetty to latest stable version Signed-off-by: raver119 <raver119@gmail.com> * test and profiles Signed-off-by: raver119 <raver119@gmail.com> * Servlet skeleton * one test case Signed-off-by: raver119 <raver119@gmail.com> * one test case Signed-off-by: raver119 <raver119@gmail.com> * compilation fix Signed-off-by: raver119 <raver119@gmail.com> * draft improvements Signed-off-by: raver119 <raver119@gmail.com> * draft improvements Signed-off-by: raver119 <raver119@gmail.com> * proof of concept works Signed-off-by: raver119 <raver119@gmail.com> * proof of concept works Signed-off-by: raver119 <raver119@gmail.com> * Servlet Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * logging + simple timing Signed-off-by: raver119 <raver119@gmail.com> * Content type fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Profile required Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Servlet tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Post test Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests added: Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Minor tweaks Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Constants used Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Check content type Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Some tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Errors checking Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Constraints and tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Minor tweaks Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Dl4j servlet skeleton Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Moving class to dl4j Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Builder extended Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * initial dl4j commit Signed-off-by: raver119 <raver119@gmail.com> * unirest version change Signed-off-by: raver119 <raver119@gmail.com> * temp fallback Signed-off-by: raver119 <raver119@gmail.com> * Reverted unirest version Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Reverted unirest version Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * revert back unirest version change Signed-off-by: raver119 <raver119@gmail.com> * revert unirest change Signed-off-by: raver119 <raver119@gmail.com> * some additional checks in builder Signed-off-by: raver119 <raver119@gmail.com> * few more fields Signed-off-by: raver119 <raver119@gmail.com> * Test added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * lombok Signed-off-by: raver119 <raver119@gmail.com> * Tests added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * deps Signed-off-by: raver119 <raver119@gmail.com> * profiles re-introduced Signed-off-by: raver119 <raver119@gmail.com> * Added tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Model servlet Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * builders Signed-off-by: raver119 <raver119@gmail.com> * builders Signed-off-by: raver119 <raver119@gmail.com> * Servlet skeleton Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Servlet tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * builders Signed-off-by: raver119 <raver119@gmail.com> * get rid of old class Signed-off-by: raver119 <raver119@gmail.com> * use PI for inference Signed-off-by: raver119 <raver119@gmail.com> * superbuilder Signed-off-by: raver119 <raver119@gmail.com> * get back builder Signed-off-by: raver119 <raver119@gmail.com> * Servlet builder Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * PI setup Signed-off-by: raver119 <raver119@gmail.com> * get rid of superbuilder Signed-off-by: raver119 <raver119@gmail.com> * SameDiffServlet inheritance constructor Signed-off-by: raver119 <raver119@gmail.com> * dl4jservlet attached to samediffservlet Signed-off-by: raver119 <raver119@gmail.com> * builder types fix Signed-off-by: raver119 <raver119@gmail.com> * dummy model Signed-off-by: raver119 <raver119@gmail.com> * single out Signed-off-by: raver119 <raver119@gmail.com> * loss Signed-off-by: raver119 <raver119@gmail.com> * Tests added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * missed builder type Signed-off-by: raver119 <raver119@gmail.com> * working serving example Signed-off-by: raver119 <raver119@gmail.com> * sd model fix Signed-off-by: raver119 <raver119@gmail.com> * fix unirest version Signed-off-by: raver119 <raver119@gmail.com> * More tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests added: Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Minor tests fixes Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Build fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Test added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Ser/deser added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * one more unirest fix Signed-off-by: raver119 <raver119@gmail.com> * Custom serializers Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Tests disabled Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * revert back unirest version change Signed-off-by: raver119 <raver119@gmail.com> * update Signed-off-by: raver119 <raver119@gmail.com> * some default fields values Signed-off-by: raver119 <raver119@gmail.com> * some comments/javadoc Signed-off-by: raver119 <raver119@gmail.com> * - move serde impls to client module - get rid of INDArray serde for now Signed-off-by: raver119 <raver119@gmail.com> * jackson-based serde for float[], double[] and String Signed-off-by: raver119 <raver119@gmail.com> * more of basic ser/de + tests Signed-off-by: raver119 <raver119@gmail.com> * minor api changes Signed-off-by: raver119 <raver119@gmail.com> * change imports/signatures Signed-off-by: raver119 <raver119@gmail.com> * Optional parralel inference Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Insert pause between tests as workaround for unavailable port issue Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * few unused imports removed Signed-off-by: raver119 <raver119@gmail.com> * Models usage Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Models usage Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * - InputAdapter + OutputAdapter = InferenceAdapter - JsonModelServer now allows separate configuration of InputAdapter and OutputAdapter Signed-off-by: raver119 <raver119@gmail.com> * unused import Signed-off-by: raver119 <raver119@gmail.com> * input adapter.. Signed-off-by: raver119 <raver119@gmail.com> * minor signature change Signed-off-by: raver119 <raver119@gmail.com> * few more signatures updated Signed-off-by: raver119 <raver119@gmail.com> * input/output adapter Signed-off-by: raver119 <raver119@gmail.com> * Tests added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * javadocs added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Test fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * minor polishing Signed-off-by: raver119 <raver119@gmail.com> * more of javadoc Signed-off-by: raver119 <raver119@gmail.com> * signature change Signed-off-by: raver119 <raver119@gmail.com>master
parent
b10ab239c0
commit
ec847e034b
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.spark.transform.client;
|
||||
|
||||
|
||||
import com.mashape.unirest.http.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
import com.mashape.unirest.http.exceptions.UnirestException;
|
||||
|
|
|
@ -51,11 +51,13 @@
|
|||
<artifactId>datavec-spark-inference-model</artifactId>
|
||||
<version>${datavec.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-spark_2.11</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-data-image</artifactId>
|
||||
|
@ -67,61 +69,73 @@
|
|||
<artifactId>akka-cluster_2.11</artifactId>
|
||||
<version>${akka.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>joda-time</groupId>
|
||||
<artifactId>joda-time</artifactId>
|
||||
<version>${jodatime.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>${commons-lang3.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.hibernate</groupId>
|
||||
<artifactId>hibernate-validator</artifactId>
|
||||
<version>${hibernate.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-reflect</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.yaml</groupId>
|
||||
<artifactId>snakeyaml</artifactId>
|
||||
<version>${snakeyaml.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-core</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-annotations</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-jdk8</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.datatype</groupId>
|
||||
<artifactId>jackson-datatype-jsr310</artifactId>
|
||||
<version>${jackson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-java_2.11</artifactId>
|
||||
|
@ -137,39 +151,44 @@
|
|||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>net.jodah</groupId>
|
||||
<artifactId>typetools</artifactId>
|
||||
<version>${jodah.typetools.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-json_2.11</artifactId>
|
||||
<version>${play.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-server_2.11</artifactId>
|
||||
<version>${play.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play_2.11</artifactId>
|
||||
<version>${play.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.typesafe.play</groupId>
|
||||
<artifactId>play-netty-server_2.11</artifactId>
|
||||
<version>${play.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
<version>${unirest.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.beust</groupId>
|
||||
<artifactId>jcommander</artifactId>
|
||||
|
|
|
@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest {
|
|||
public static void before() throws Exception {
|
||||
server = new CSVSparkTransformServer();
|
||||
FileUtils.write(fileSave, transformProcess.toJson());
|
||||
|
||||
// Only one time
|
||||
Unirest.setObjectMapper(new ObjectMapper() {
|
||||
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
|
||||
|
@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest {
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.runMain(new String[] {"-dp", "9050"});
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.spark.transform;
|
||||
|
||||
|
||||
import com.mashape.unirest.http.JsonNode;
|
||||
import com.mashape.unirest.http.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
|
@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest {
|
|||
server = new CSVSparkTransformServer();
|
||||
FileUtils.write(fileSave, transformProcess.toJson());
|
||||
// Only one time
|
||||
|
||||
Unirest.setObjectMapper(new ObjectMapper() {
|
||||
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
|
||||
new org.nd4j.shade.jackson.databind.ObjectMapper();
|
||||
|
@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest {
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.spark.transform;
|
||||
|
||||
|
||||
import com.mashape.unirest.http.JsonNode;
|
||||
import com.mashape.unirest.http.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.spark.transform;
|
||||
|
||||
|
||||
import com.mashape.unirest.http.JsonNode;
|
||||
import com.mashape.unirest.http.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
|
|
|
@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client;
|
|||
import com.mashape.unirest.http.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
import com.mashape.unirest.request.HttpRequest;
|
||||
import com.mashape.unirest.request.HttpRequestWithBody;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nearestneighbor.model.*;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.serde.base64.Nd4jBase64;
|
||||
|
@ -51,6 +51,7 @@ public class NearestNeighborsClient {
|
|||
|
||||
static {
|
||||
// Only one time
|
||||
|
||||
Unirest.setObjectMapper(new ObjectMapper() {
|
||||
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
|
||||
new org.nd4j.shade.jackson.databind.ObjectMapper();
|
||||
|
@ -89,7 +90,7 @@ public class NearestNeighborsClient {
|
|||
NearestNeighborRequest request = new NearestNeighborRequest();
|
||||
request.setInputIndex(index);
|
||||
request.setK(k);
|
||||
HttpRequestWithBody req = Unirest.post(url + "/knn");
|
||||
val req = Unirest.post(url + "/knn");
|
||||
req.header("accept", "application/json")
|
||||
.header("Content-Type", "application/json").body(request);
|
||||
addAuthHeader(req);
|
||||
|
@ -112,7 +113,7 @@ public class NearestNeighborsClient {
|
|||
Base64NDArrayBody base64NDArrayBody =
|
||||
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
|
||||
|
||||
HttpRequestWithBody req = Unirest.post(url + "/knnnew");
|
||||
val req = Unirest.post(url + "/knnnew");
|
||||
req.header("accept", "application/json")
|
||||
.header("Content-Type", "application/json").body(base64NDArrayBody);
|
||||
addAuthHeader(req);
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec;
|
|||
import com.google.gson.JsonArray;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
import jdk.nashorn.internal.objects.annotations.Property;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import org.apache.commons.compress.compressors.gzip.GzipUtils;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.deeplearning4j.nn.adapters;
|
||||
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.nn.api;
|
||||
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
|
|
|
@ -24,6 +24,7 @@ import lombok.val;
|
|||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
|
|
|
@ -25,6 +25,7 @@ import lombok.val;
|
|||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
|
||||
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
||||
import org.deeplearning4j.eval.RegressionEvaluation;
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<parent>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>deeplearning4j-json-server</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<name>deeplearning4j-json-server</name>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>${lombok.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-json-client</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-json-server</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-parallel-wrapper</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>${slf4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-core</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
<activation>
|
||||
<activeByDefault>true</activeByDefault>
|
||||
</activation>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.1</id>
|
||||
<activation>
|
||||
<activeByDefault>false</activeByDefault>
|
||||
</activation>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-cuda-10.1</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -0,0 +1,205 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.parallelism.ParallelInference;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.remote.serving.SameDiffServlet;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author astoyakin
|
||||
*/
|
||||
@Slf4j
|
||||
@NoArgsConstructor
|
||||
public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
|
||||
|
||||
protected ParallelInference parallelInference;
|
||||
protected Model model;
|
||||
protected boolean parallelEnabled = true;
|
||||
|
||||
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.parallelInference = parallelInference;
|
||||
this.model = null;
|
||||
this.parallelEnabled = true;
|
||||
}
|
||||
|
||||
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
|
||||
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
|
||||
super(inferenceAdapter, serializer, deserializer);
|
||||
this.model = model;
|
||||
this.parallelInference = null;
|
||||
this.parallelEnabled = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||
String processorReturned = "";
|
||||
String path = request.getPathInfo();
|
||||
if (path.equals(SERVING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
if (validateRequest(request,response)) {
|
||||
val stream = request.getInputStream();
|
||||
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
||||
char[] charBuffer = new char[128];
|
||||
int bytesRead = -1;
|
||||
val buffer = new StringBuilder();
|
||||
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
|
||||
buffer.append(charBuffer, 0, bytesRead);
|
||||
}
|
||||
val requestString = buffer.toString();
|
||||
|
||||
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||
|
||||
O result = null;
|
||||
if (parallelEnabled) {
|
||||
// process result
|
||||
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
}
|
||||
else {
|
||||
synchronized(this) {
|
||||
if (model instanceof ComputationGraph)
|
||||
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
|
||||
else if (model instanceof MultiLayerNetwork) {
|
||||
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
|
||||
"Input data for MultilayerNetwork is invalid!");
|
||||
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
|
||||
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
|
||||
}
|
||||
}
|
||||
}
|
||||
processorReturned = serializer.serialize(result);
|
||||
}
|
||||
} else {
|
||||
// we return error otherwise
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
try {
|
||||
val out = response.getWriter();
|
||||
out.write(processorReturned);
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates servlet to serve models
|
||||
*
|
||||
* @param <I> type of Input class
|
||||
* @param <O> type of Output class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public static class Builder<I,O> {
|
||||
|
||||
private ParallelInference pi;
|
||||
private Model model;
|
||||
|
||||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
private int port;
|
||||
private boolean parallelEnabled = true;
|
||||
|
||||
public Builder(@NonNull ParallelInference pi) {
|
||||
this.pi = pi;
|
||||
}
|
||||
|
||||
public Builder(@NonNull Model model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is required to specify serializer
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows to specify deserializer
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows to specify port
|
||||
*
|
||||
* @param port
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> port(int port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method activates parallel inference
|
||||
*
|
||||
* @param parallelEnabled
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> parallelEnabled(boolean parallelEnabled) {
|
||||
this.parallelEnabled = parallelEnabled;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DL4jServlet<I,O> build() {
|
||||
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
|
||||
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,392 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.parallelism.ParallelInference;
|
||||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.adapters.InputAdapter;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.remote.SameDiffJsonModelServer;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models
|
||||
*
|
||||
* Server url will be http://0.0.0.0:{port}>/v1/serving
|
||||
* Server only accepts POST requests
|
||||
*
|
||||
* @param <I> type of the input class, i.e. String
|
||||
* @param <O> type of the output class, i.e. Sentiment
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
|
||||
|
||||
// all serving goes through ParallelInference
|
||||
protected ParallelInference parallelInference;
|
||||
|
||||
|
||||
protected ModelAdapter<O> modelAdapter;
|
||||
|
||||
// actual models
|
||||
protected ComputationGraph cgModel;
|
||||
protected MultiLayerNetwork mlnModel;
|
||||
|
||||
// service stuff
|
||||
protected InferenceMode inferenceMode;
|
||||
protected int numWorkers;
|
||||
|
||||
protected boolean enabledParallel = true;
|
||||
|
||||
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
|
||||
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.cgModel = cgModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.mlnModel = mlnModel;
|
||||
this.inferenceMode = inferenceMode;
|
||||
this.numWorkers = numWorkers;
|
||||
}
|
||||
|
||||
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
|
||||
super(inferenceAdapter, serializer, deserializer, port);
|
||||
|
||||
this.parallelInference = pi;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method stops server
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
@Override
|
||||
public void stop() throws Exception {
|
||||
if (parallelInference != null)
|
||||
parallelInference.shutdown();
|
||||
super.stop();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method starts server
|
||||
* @throws Exception
|
||||
*/
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
// if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case
|
||||
if (sdModel != null) {
|
||||
super.start();
|
||||
return;
|
||||
}
|
||||
Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined");
|
||||
|
||||
val model = cgModel != null ? (Model) cgModel : (Model) mlnModel;
|
||||
// PI construction is optional, since we can have it defined
|
||||
if (enabledParallel) {
|
||||
if (parallelInference == null) {
|
||||
Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");
|
||||
|
||||
parallelInference = new ParallelInference.Builder(model)
|
||||
.inferenceMode(inferenceMode)
|
||||
.workers(numWorkers)
|
||||
.loadBalanceMode(LoadBalanceMode.FIFO)
|
||||
.batchLimit(16)
|
||||
.queueLimit(128)
|
||||
.build();
|
||||
}
|
||||
servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
|
||||
.parallelEnabled(true)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
else {
|
||||
servingServlet = new DL4jServlet.Builder<I, O>(model)
|
||||
.parallelEnabled(false)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.build();
|
||||
}
|
||||
start(port, servingServlet);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates servlet to serve different types of models
|
||||
*
|
||||
* @param <I> type of Input class
|
||||
* @param <O> type of Output class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
* @author astoyakin
|
||||
*/
|
||||
public static class Builder<I,O> {
|
||||
|
||||
private SameDiff sdModel;
|
||||
private ComputationGraph cgModel;
|
||||
private MultiLayerNetwork mlnModel;
|
||||
private ParallelInference pi;
|
||||
|
||||
private String[] orderedInputNodes;
|
||||
private String[] orderedOutputNodes;
|
||||
|
||||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
|
||||
private InputAdapter<I> inputAdapter;
|
||||
private OutputAdapter<O> outputAdapter;
|
||||
|
||||
private int port;
|
||||
|
||||
private boolean parallelMode = true;
|
||||
|
||||
// these fields actually require defaults
|
||||
private InferenceMode inferenceMode = InferenceMode.BATCHED;
|
||||
private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();
|
||||
|
||||
public Builder(@NonNull SameDiff sdModel) {
|
||||
this.sdModel = sdModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull MultiLayerNetwork mlnModel) {
|
||||
this.mlnModel = mlnModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull ComputationGraph cgModel) {
|
||||
this.cgModel = cgModel;
|
||||
}
|
||||
|
||||
public Builder(@NonNull ParallelInference pi) {
|
||||
this.pi = pi;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type
|
||||
* @param inferenceAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify InputAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, and will require OutputAdapter<O> defined
|
||||
* @param inputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
|
||||
this.inputAdapter = inputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify OutputtAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, and will require InputAdapter<I> defined
|
||||
* @param outputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
|
||||
this.outputAdapter = outputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify serializer
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify deserializer
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
|
||||
*
|
||||
* @param inferenceMode
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inferenceMode(@NonNull InferenceMode inferenceMode) {
|
||||
this.inferenceMode = inferenceMode;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify number of worker threads for ParallelInference
|
||||
*
|
||||
* @param numWorkers
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> numWorkers(int numWorkers) {
|
||||
this.numWorkers = numWorkers;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
|
||||
*
|
||||
* PLEASE NOTE: this argument only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(String... args) {
|
||||
orderedInputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
|
||||
*
|
||||
* PLEASE NOTE: this argument only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(@NonNull List<String> args) {
|
||||
orderedInputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify output nodes
|
||||
*
|
||||
* PLEASE NOTE: this argument only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(String... args) {
|
||||
Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify output nodes
|
||||
*
|
||||
* PLEASE NOTE: this argument only used for SameDiff models
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(@NonNull List<String> args) {
|
||||
Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify http port
|
||||
*
|
||||
* PLEASE NOTE: port must be free and be in range regular TCP/IP ports range
|
||||
* @param port
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> port(int port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method switches on ParallelInference usage
|
||||
* @param - true - to use ParallelInference, false - to use ComputationGraph or
|
||||
* MultiLayerNetwork directly
|
||||
*
|
||||
* PLEASE NOTE: this doesn't apply to SameDiff models
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
public Builder<I,O> parallelMode(boolean enable) {
|
||||
this.parallelMode = enable;
|
||||
return this;
|
||||
}
|
||||
|
||||
public JsonModelServer<I,O> build() {
|
||||
if (inferenceAdapter == null) {
|
||||
if (inputAdapter != null && outputAdapter != null) {
|
||||
inferenceAdapter = new InferenceAdapter<I, O>() {
|
||||
@Override
|
||||
public MultiDataSet apply(I input) {
|
||||
return inputAdapter.apply(input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public O apply(INDArray... outputs) {
|
||||
return outputAdapter.apply(outputs);
|
||||
}
|
||||
};
|
||||
} else
|
||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||
}
|
||||
|
||||
if (sdModel != null) {
|
||||
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
|
||||
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
} else if (cgModel != null)
|
||||
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (mlnModel != null)
|
||||
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
|
||||
else if (pi != null)
|
||||
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port);
|
||||
else
|
||||
throw new IllegalStateException("No models were defined for JsonModelServer");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,749 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.parallelism.inference.InferenceMode;
|
||||
import org.deeplearning4j.remote.helpers.House;
|
||||
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
|
||||
import org.deeplearning4j.remote.helpers.PredictedPrice;
|
||||
import org.junit.After;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
import org.nd4j.remote.clients.JsonRemoteInference;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
|
||||
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class JsonModelServerTest {
|
||||
private static final MultiLayerNetwork model;
|
||||
private final int PORT = 18080;
|
||||
|
||||
static {
|
||||
val conf = new NeuralNetConfiguration.Builder()
|
||||
.seed(119)
|
||||
.updater(new Adam(0.119f))
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build())
|
||||
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build())
|
||||
.build();
|
||||
|
||||
model = new MultiLayerNetwork(conf);
|
||||
model.init();
|
||||
}
|
||||
|
||||
@After
|
||||
public void pause() throws Exception {
|
||||
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
|
||||
TimeUnit.SECONDS.sleep(2);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testStartStopParallel() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
try {
|
||||
serverDL.start();
|
||||
serverSD.start();
|
||||
|
||||
val clientDL = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
PredictedPrice price = clientDL.predict(house);
|
||||
long timeStart = System.currentTimeMillis();
|
||||
price = clientDL.predict(house);
|
||||
long timeStop = System.currentTimeMillis();
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
val clientSD = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
|
||||
.build();
|
||||
|
||||
PredictedPrice price2 = clientSD.predict(house);
|
||||
timeStart = System.currentTimeMillis();
|
||||
price = clientSD.predict(house);
|
||||
timeStop = System.currentTimeMillis();
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 3.0, price.getPrice(), 1e-5);
|
||||
|
||||
}
|
||||
finally {
|
||||
serverSD.stop();
|
||||
serverDL.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStartStopSequential() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
|
||||
serverDL.start();
|
||||
serverDL.stop();
|
||||
|
||||
serverSD.start();
|
||||
serverSD.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForSD() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
price = client.predict(house);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
|
||||
}
|
||||
finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForDLSynchronized() throws Exception {
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(INPLACE)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build();
|
||||
House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house1);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
PredictedPrice price1 = client.predict(house1);
|
||||
PredictedPrice price2 = client.predict(house2);
|
||||
PredictedPrice price3 = client.predict(house3);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void basicServingTestForDL() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.numWorkers(1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
price = client.predict(house);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
|
||||
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_1() {
|
||||
String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
|
||||
val deserializer = new House.HouseDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(2, result.getDistrict());
|
||||
assertEquals(100, result.getArea());
|
||||
assertEquals(2, result.getBathrooms());
|
||||
assertEquals(3, result.getBedrooms());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_2() {
|
||||
String request = "{\"price\":1}";
|
||||
val deserializer = new PredictedPrice.PredictedPriceDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(1.0, result.getPrice(), 1e-4);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void negativeServingTest_1() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(null)
|
||||
.port(18080)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test //(expected = NullPointerException.class)
|
||||
public void negativeServingTest_2() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IOException.class)
|
||||
public void negativeServingTest_3() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void asyncServingTest() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) 0.421444, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
}
|
||||
finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void negativeAsyncTest() throws Exception {
|
||||
|
||||
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.inferenceMode(InferenceMode.BATCHED)
|
||||
.numWorkers(1)
|
||||
.port(PORT)
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
|
||||
// Fake deserializer to test failure
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
try {
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
} catch (ExecutionException e) {
|
||||
assertTrue(e.getMessage().contains("Deserialization failed"));
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testSameDiffMnist() throws Exception {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
|
||||
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(sd)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT+1)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try{
|
||||
server.start();
|
||||
for( int i=0; i<10; i++ ){
|
||||
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28);
|
||||
INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax");
|
||||
float[] fArr = f.toFloatVector();
|
||||
int out = client.predict(fArr);
|
||||
assertEquals(exp.argMax().getInt(0), out);
|
||||
}
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMlnMnist() throws Exception {
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.list()
|
||||
.layer(new DenseLayer.Builder().nIn(784).nOut(10).build())
|
||||
.layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28);
|
||||
INDArray exp = net.output(f);
|
||||
float[] fArr = f.toFloatVector();
|
||||
int out = client.predict(fArr);
|
||||
assertEquals(exp.argMax().getInt(0), out);
|
||||
}
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompGraph() throws Exception {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.graphBuilder()
|
||||
.addInputs("input1", "input2")
|
||||
.addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
|
||||
.addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
|
||||
.addVertex("merge", new MergeVertex(), "L1", "L2")
|
||||
.addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge")
|
||||
.setOutputs("out")
|
||||
.build();
|
||||
|
||||
ComputationGraph net = new ComputationGraph(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("in")
|
||||
.orderedOutputNodes("softmax")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
//client.predict(new float[]{0.0f, 1.0f, 2.0f});
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompGraph_1() throws Exception {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.updater(new Sgd(0.01))
|
||||
.graphBuilder()
|
||||
.addInputs("input")
|
||||
.addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input")
|
||||
.addLayer("out1", new OutputLayer.Builder()
|
||||
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||
.nIn(4).nOut(3).build(), "L1")
|
||||
.addLayer("out2", new OutputLayer.Builder()
|
||||
.lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.nIn(4).nOut(2).build(), "L1")
|
||||
.setOutputs("out1","out2")
|
||||
.build();
|
||||
|
||||
final ComputationGraph net = new ComputationGraph(conf);
|
||||
net.init();
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(net)
|
||||
.outputSerializer( new IntSerde())
|
||||
.inputDeserializer(new FloatSerde())
|
||||
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
|
||||
@Override
|
||||
public MultiDataSet apply(float[] input) {
|
||||
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer apply(INDArray... nnOutput) {
|
||||
return nnOutput[0].argMax().getInt(0);
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("input")
|
||||
.orderedOutputNodes("out")
|
||||
.port(PORT + 1)
|
||||
.inferenceMode(SEQUENTIAL)
|
||||
.numWorkers(2)
|
||||
.parallelMode(false)
|
||||
.build();
|
||||
|
||||
val client = JsonRemoteInference.<float[], Integer>builder()
|
||||
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
|
||||
.outputDeserializer(new IntSerde())
|
||||
.inputSerializer( new FloatSerde())
|
||||
.build();
|
||||
|
||||
try {
|
||||
server.start();
|
||||
val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||
assertNotNull(result);
|
||||
} catch (Exception e){
|
||||
e.printStackTrace();
|
||||
throw e;
|
||||
} finally {
|
||||
server.stop();
|
||||
}
|
||||
}
|
||||
|
||||
private static class FloatSerde implements JsonSerializer<float[]>, JsonDeserializer<float[]>{
|
||||
private final ObjectMapper om = new ObjectMapper();
|
||||
|
||||
@Override
|
||||
public float[] deserialize(String json) {
|
||||
try {
|
||||
return om.readValue(json, FloatHolder.class).getFloats();
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(float[] o) {
|
||||
try{
|
||||
return om.writeValueAsString(new FloatHolder(o));
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
//Use float holder so Jackson does ser/de properly (no "{}" otherwise)
|
||||
@AllArgsConstructor @NoArgsConstructor @Data
|
||||
private static class FloatHolder {
|
||||
private float[] floats;
|
||||
}
|
||||
}
|
||||
|
||||
private static class IntSerde implements JsonSerializer<Integer>, JsonDeserializer<Integer> {
|
||||
private final ObjectMapper om = new ObjectMapper();
|
||||
|
||||
@Override
|
||||
public Integer deserialize(String json) {
|
||||
try {
|
||||
return om.readValue(json, Integer.class);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(Integer o) {
|
||||
try{
|
||||
return om.writeValueAsString(o);
|
||||
} catch (IOException e){
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote;
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.impl.client.HttpClientBuilder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ServletTest {
|
||||
|
||||
private JsonModelServer server;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
server = new JsonModelServer.Builder<String,String>(sd)
|
||||
.port(8080)
|
||||
.inferenceAdapter(new InferenceAdapter<String, String>() {
|
||||
@Override
|
||||
public MultiDataSet apply(String input) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String apply(INDArray... nnOutput) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.outputSerializer(new JsonSerializer<String>() {
|
||||
@Override
|
||||
public String serialize(String o) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.inputDeserializer(new JsonDeserializer<String>() {
|
||||
@Override
|
||||
public String deserialize(String json) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.orderedInputNodes("input")
|
||||
.orderedOutputNodes("output")
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
//server.join();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getEndpoints() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "application/json");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(200, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypeGet() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypePost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void postForServing() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(500, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundPost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving/some");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(404, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundGet() throws Exception {
|
||||
val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
|
||||
requestGet.setHeader("Content-type", "application/json");
|
||||
|
||||
val responseGet = HttpClientBuilder.create().build().execute( requestGet );
|
||||
assertEquals(404, responseGet.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.*;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class House {
|
||||
private int district;
|
||||
private int bedrooms;
|
||||
private int bathrooms;
|
||||
private int area;
|
||||
|
||||
|
||||
public static class HouseSerializer implements JsonSerializer<House> {
|
||||
@Override
|
||||
public String serialize(@NonNull House o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class HouseDeserializer implements JsonDeserializer<House> {
|
||||
@Override
|
||||
public House deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, House.class);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
@Slf4j
|
||||
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
|
||||
|
||||
@Override
|
||||
public MultiDataSet apply(@NonNull House input) {
|
||||
// we just create vector array with shape[4] and assign it's value to the district value
|
||||
return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PredictedPrice apply(INDArray... nnOutput) {
|
||||
return new PredictedPrice(nnOutput[0].getFloat(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PredictedPrice {
|
||||
private float price;
|
||||
|
||||
public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
|
||||
@Override
|
||||
public String serialize(@NonNull PredictedPrice o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
|
||||
@Override
|
||||
public PredictedPrice deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, PredictedPrice.class);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.eclipse.jetty" level="WARN" />
|
||||
<logger name="org.apache.catalina.core" level="WARN" />
|
||||
<logger name="org.springframework" level="WARN" />
|
||||
<logger name="org.nd4j" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="INFO" />
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
|
@ -0,0 +1,30 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<modules>
|
||||
<module>deeplearning4j-json-server</module>
|
||||
</modules>
|
||||
|
||||
<parent>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>deeplearning4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<name>deeplearning4j-remote</name>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>test-nd4j-native</id>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-nd4j-cuda-10.1</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -21,7 +21,6 @@ import lombok.*;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
|
|
|
@ -22,7 +22,6 @@ import lombok.val;
|
|||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.nn.api.ModelAdapter;
|
||||
import org.deeplearning4j.nn.api.OutputAdapter;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
|
|
|
@ -144,6 +144,7 @@
|
|||
<module>dl4j-perf</module>
|
||||
<module>dl4j-integration-tests</module>
|
||||
<module>deeplearning4j-common</module>
|
||||
<module>deeplearning4j-remote</module>
|
||||
</modules>
|
||||
|
||||
<dependencyManagement>
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.adapters;
|
||||
|
||||
/**
|
||||
* This interface describes methods needed to convert custom JVM objects to INDArrays, suitable for feeding neural networks
|
||||
*
|
||||
* @param <I> type of the Input for the model. I.e. String for raw text
|
||||
* @param <O> type of the Output for the model, I.e. Sentiment, for Text->Sentiment extraction
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public interface InferenceAdapter<I, O> extends InputAdapter<I>, OutputAdapter<O> {
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.adapters;
|
||||
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
|
||||
/**
|
||||
* This interface describes method for transformation from object of type I to MultiDataSet.
|
||||
*
|
||||
*/
|
||||
public interface InputAdapter<I> {
|
||||
/**
|
||||
* This method converts input object to MultiDataSet
|
||||
* @param input
|
||||
* @return
|
||||
*/
|
||||
MultiDataSet apply(I input);
|
||||
}
|
|
@ -14,15 +14,15 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.api;
|
||||
package org.nd4j.adapters;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* This interface describes entity used to conver neural network output to specified class.
|
||||
* I.e. INDArray -> int[] on the fly.
|
||||
* This interface describes entity used to convert neural network output to specified class.
|
||||
* I.e. INDArray -> int[] or INDArray -> Sentiment on the fly
|
||||
*
|
||||
* PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference.
|
||||
* This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly.
|
|
@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
|||
public native void printBuffer();
|
||||
public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/);
|
||||
|
||||
/**
|
||||
* print element by element consequently in a way they (elements) are stored in physical memory
|
||||
*/
|
||||
public native void printLinearBuffer();
|
||||
|
||||
/**
|
||||
* prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status
|
||||
*/
|
||||
|
@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength);
|
||||
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
|
||||
|
||||
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder);
|
||||
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder);
|
||||
|
@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* Returns the prod of the data
|
||||
* up to the given length
|
||||
*/
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length);
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length);
|
||||
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length);
|
||||
|
@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* @param originalTadNum the tad number for the reduced version of the problem
|
||||
*/
|
||||
|
||||
/**
|
||||
* Returns the prod of the data
|
||||
* up to the given length
|
||||
*/
|
||||
|
||||
/**
|
||||
* Returns the prod of the data
|
||||
|
|
|
@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
|||
public native void printBuffer();
|
||||
public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/);
|
||||
|
||||
/**
|
||||
* print element by element consequently in a way they (elements) are stored in physical memory
|
||||
*/
|
||||
public native void printLinearBuffer();
|
||||
|
||||
/**
|
||||
* prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status
|
||||
*/
|
||||
|
@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength);
|
||||
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
|
||||
|
||||
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder);
|
||||
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder);
|
||||
|
@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* Returns the prod of the data
|
||||
* up to the given length
|
||||
*/
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length);
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length);
|
||||
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length);
|
||||
|
@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* @param originalTadNum the tad number for the reduced version of the problem
|
||||
*/
|
||||
|
||||
/**
|
||||
* Returns the prod of the data
|
||||
* up to the given length
|
||||
*/
|
||||
|
||||
/**
|
||||
* Returns the prod of the data
|
||||
|
|
|
@ -50,6 +50,8 @@
|
|||
<artifactId>httpmime</artifactId>
|
||||
<version>${httpmime.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
|
|
|
@ -54,11 +54,13 @@
|
|||
<artifactId>httpmime</artifactId>
|
||||
<version>${httpmime.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
<version>${unirest.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-jackson</artifactId>
|
||||
|
|
|
@ -21,13 +21,15 @@ import com.beust.jcommander.Parameter;
|
|||
import com.beust.jcommander.ParameterException;
|
||||
import com.beust.jcommander.Parameters;
|
||||
import com.google.common.primitives.Ints;
|
||||
import com.mashape.unirest.http.HttpResponse;
|
||||
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
import io.aeron.Aeron;
|
||||
import io.aeron.driver.MediaDriver;
|
||||
import io.aeron.driver.ThreadingMode;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import org.agrona.CloseHelper;
|
||||
import org.agrona.concurrent.BusySpinIdleStrategy;
|
||||
import org.json.JSONObject;
|
||||
|
@ -49,7 +51,6 @@ import org.nd4j.parameterserver.updater.SoftSyncParameterUpdater;
|
|||
import org.nd4j.parameterserver.updater.SynchronousParameterUpdater;
|
||||
import org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage;
|
||||
import org.nd4j.parameterserver.util.CheckSocket;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -342,7 +343,7 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
|||
JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
|
||||
String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort,
|
||||
streamId);
|
||||
HttpResponse<String> entity = Unirest.post(url).header("Content-Type", "application/json")
|
||||
val entity = Unirest.post(url).header("Content-Type", "application/json")
|
||||
.body(jsonObject).asString();
|
||||
} catch (Exception e) {
|
||||
failCount.incrementAndGet();
|
||||
|
|
|
@ -19,13 +19,13 @@
|
|||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>nd4j-serde</artifactId>
|
||||
<artifactId>nd4j-remote</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>nd4j-grpc</artifactId>
|
||||
<artifactId>nd4j-grpc-client</artifactId>
|
||||
|
||||
<name>nd4j-grpc</name>
|
||||
<!-- FIXME change it to the project's website -->
|
|
@ -0,0 +1,70 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<parent>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>nd4j-json-client</artifactId>
|
||||
|
||||
<name>nd4j-json-client</name>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.mashape.unirest</groupId>
|
||||
<artifactId>unirest-java</artifactId>
|
||||
<version>${unirest.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>jackson</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
<!-- Put nd4j-native in profile so that CUDA-only builds succeed -->
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<classifier>${javacpp.platform}</classifier>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -0,0 +1,153 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients;
|
||||
|
||||
import com.mashape.unirest.http.HttpResponse;
|
||||
import com.mashape.unirest.http.Unirest;
|
||||
import com.mashape.unirest.http.exceptions.UnirestException;
|
||||
import lombok.Builder;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.json.JSONObject;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
/**
|
||||
* This class provides remote inference functionality via JSON-powered REST APIs.
|
||||
*
|
||||
* Basically we assume that there's remote JSON server available (on bare metal or in k8s/swarm/whatever cluster), and with proper serializers/deserializers provided we can issue REST requests and get back responses.
|
||||
* So, in this way application logic can be separated from DL logic.
|
||||
*
|
||||
* You just need to provide serializer/deserializer and address of the REST server, i.e. "http://model:8080/v1/serving"
|
||||
*
|
||||
* @param <I> type of the input class, i.e. String
|
||||
* @param <O> type of the output class, i.e. Sentiment
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Slf4j
|
||||
public class JsonRemoteInference<I, O> {
|
||||
private String endpointAddress;
|
||||
private JsonSerializer<I> serializer;
|
||||
private JsonDeserializer<O> deserializer;
|
||||
|
||||
@Builder
|
||||
public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer<I> inputSerializer, @NonNull JsonDeserializer<O> outputDeserializer) {
|
||||
this.endpointAddress = endpointAddress;
|
||||
this.serializer = inputSerializer;
|
||||
this.deserializer = outputDeserializer;
|
||||
}
|
||||
|
||||
private O processResponse(HttpResponse<String> response) throws IOException {
|
||||
if (response.getStatus() != 200)
|
||||
throw new IOException("Inference request returned bad error code: " + response.getStatus());
|
||||
|
||||
O result = deserializer.deserialize(response.getBody());
|
||||
if (result == null) {
|
||||
throw new IOException("Deserialization failed!");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method does remote inference in a blocking way
|
||||
*
|
||||
* @param input
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public O predict(I input) throws IOException {
|
||||
try {
|
||||
val stringResult = Unirest.post(endpointAddress)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.body(new JSONObject(serializer.serialize(input))).asString();
|
||||
|
||||
return processResponse(stringResult);
|
||||
} catch (UnirestException e) {
|
||||
throw new IOException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method does remote inference in asynchronous way, returning Future instead
|
||||
* @param input
|
||||
* @return
|
||||
*/
|
||||
public Future<O> predictAsync(I input) {
|
||||
val stringResult = Unirest.post(endpointAddress)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "application/json")
|
||||
.body(new JSONObject(serializer.serialize(input))).asStringAsync();
|
||||
return new InferenceFuture(stringResult);
|
||||
}
|
||||
|
||||
/**
|
||||
* This class holds a Future of the object returned by remote inference server
|
||||
*/
|
||||
private class InferenceFuture implements Future<O> {
|
||||
private Future<HttpResponse<String>> unirestFuture;
|
||||
|
||||
private InferenceFuture(@NonNull Future<HttpResponse<String>> future) {
|
||||
this.unirestFuture = future;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean cancel(boolean mayInterruptIfRunning) {
|
||||
return unirestFuture.cancel(mayInterruptIfRunning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCancelled() {
|
||||
return unirestFuture.isCancelled();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return unirestFuture.isDone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public O get() throws InterruptedException, ExecutionException {
|
||||
val stringResult = unirestFuture.get();
|
||||
|
||||
try {
|
||||
return processResponse(stringResult);
|
||||
} catch (IOException e) {
|
||||
throw new ExecutionException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public O get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
|
||||
val stringResult = unirestFuture.get(timeout, unit);
|
||||
|
||||
try {
|
||||
return processResponse(stringResult);
|
||||
} catch (IOException e) {
|
||||
throw new ExecutionException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde;
|
||||
|
||||
/**
|
||||
* This interface describes basic JSON deserializer interface used for JsonRemoteInference
|
||||
* @param <T> type of the deserializable class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public interface JsonDeserializer<T> {
|
||||
|
||||
/**
|
||||
* This method serializes given object into JSON-string
|
||||
* @param json string containing JSON representation of the object
|
||||
* @return
|
||||
*/
|
||||
T deserialize(String json);
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde;
|
||||
|
||||
/**
|
||||
* This interface describes basic JSON serializer interface used for JsonRemoteInference
|
||||
* @param <T> type of the serializable class
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public interface JsonSerializer<T> {
|
||||
|
||||
/**
|
||||
* This method serializes given object into JSON-string
|
||||
*
|
||||
* @param o object to be serialized
|
||||
* @return
|
||||
*/
|
||||
String serialize(T o);
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public abstract class AbstractSerDe<T> implements JsonDeserializer<T>, JsonSerializer<T> {
|
||||
protected ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
|
||||
protected String serializeClass(@NonNull T obj) {
|
||||
try {
|
||||
return objectMapper.writeValueAsString(obj);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
protected T deserializeClass(@NonNull String json, @NonNull Class<T> cls) {
|
||||
try {
|
||||
return objectMapper.readValue(json, cls);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java Boolean. Single value only.
|
||||
*/
|
||||
public class BooleanSerde extends AbstractSerDe<Boolean> {
|
||||
@Override
|
||||
public Boolean deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, Boolean.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull Boolean o) {
|
||||
return serializeClass(o);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.*;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java double[]
|
||||
*/
|
||||
public class DoubleArraySerde extends AbstractSerDe<double[]> {
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull double[] data) {
|
||||
return serializeClass(data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, double[].class);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java Double. Single value only.
|
||||
*/
|
||||
public class DoubleSerde extends AbstractSerDe<Double> {
|
||||
@Override
|
||||
public Double deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, Double.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull Double o) {
|
||||
return serializeClass(o);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
|
||||
import lombok.*;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java float[]
|
||||
*/
|
||||
public class FloatArraySerde extends AbstractSerDe<float[]> {
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull float[] data) {
|
||||
return serializeClass(data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, float[].class);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java Float. Single value only.
|
||||
*/
|
||||
public class FloatSerde extends AbstractSerDe<Float> {
|
||||
@Override
|
||||
public Float deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, Float.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull Float o) {
|
||||
return serializeClass(o);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
|
||||
/**
|
||||
* This class provides JSON ser/de for Java Integer. Single value only.
|
||||
*/
|
||||
public class IntegerSerde extends AbstractSerDe<Integer> {
|
||||
@Override
|
||||
public Integer deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, Integer.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull Integer o) {
|
||||
return serializeClass(o);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.clients.serde.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* This class provides fake JSON serializer/deserializer functionality for String.
|
||||
* It doesn't put any JSON-specific bits into actual string
|
||||
*/
|
||||
public class StringSerde extends AbstractSerDe<String> {
|
||||
|
||||
@Override
|
||||
public String serialize(@NonNull String data) {
|
||||
return serializeClass(data);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String deserialize(@NonNull String json) {
|
||||
return deserializeClass(json, String.class);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
## SameDiff model serving
|
||||
|
||||
This modules provides JSON-based serving of SameDiff models
|
||||
|
||||
## Example
|
||||
|
||||
First of all we'll create server instance. Most probably you'll do it in application that will be running in container
|
||||
```java
|
||||
val server = SameDiffJsonModelServer.<String, Sentiment>builder()
|
||||
.adapter(new StringToSentimentAdapter())
|
||||
.model(mySameDiffModel)
|
||||
.port(8080)
|
||||
.serializer(new SentimentSerializer())
|
||||
.deserializer(new StringDeserializer())
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
server.join();
|
||||
```
|
||||
|
||||
Now, presumably in some other container, we'll set up remote inference client:
|
||||
```java
|
||||
val client = JsonRemoteInference.<String, Sentiment>builder()
|
||||
.endpointAddress("http://youraddress:8080/v1/serving")
|
||||
.serializer(new StringSerializer())
|
||||
.deserializer(new SentimentDeserializer())
|
||||
.build();
|
||||
|
||||
Sentiment result = client.predict(myText);
|
||||
```
|
||||
On top of that, there's async call available, for cases when you need to chain multiple requests to one or multiple remote model servers.
|
||||
|
||||
```java
|
||||
Future<Sentiment> result = client.predictAsync(myText);
|
||||
```
|
|
@ -0,0 +1,179 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<parent>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>nd4j-json-server</artifactId>
|
||||
<name>nd4j-json-server</name>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
<jersey.version>2.26</jersey.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-json-client</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.core</groupId>
|
||||
<artifactId>jersey-client</artifactId>
|
||||
<version>${jersey.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.core</groupId>
|
||||
<artifactId>jersey-server</artifactId>
|
||||
<version>${jersey.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.eclipse.jetty</groupId>
|
||||
<artifactId>jetty-server</artifactId>
|
||||
<version>9.4.19.v20190610</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.eclipse.jetty</groupId>
|
||||
<artifactId>jetty-servlet</artifactId>
|
||||
<version>9.4.19.v20190610</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.inject</groupId>
|
||||
<artifactId>jersey-hk2</artifactId>
|
||||
<version>${jersey.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.media</groupId>
|
||||
<artifactId>jersey-media-json-processing</artifactId>
|
||||
<version>${jersey.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.glassfish.jersey.containers</groupId>
|
||||
<artifactId>jersey-container-servlet-core</artifactId>
|
||||
<version>${jersey.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-core</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>javax.xml.bind</groupId>
|
||||
<artifactId>jaxb-api</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-impl</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-core</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>javax.activation</groupId>
|
||||
<artifactId>activation</artifactId>
|
||||
<version>1.1</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>nd4j-tests-cpu</id>
|
||||
<activation>
|
||||
<activeByDefault>true</activeByDefault>
|
||||
</activation>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
|
||||
<profile>
|
||||
<id>nd4j-tests-cuda</id>
|
||||
<activation>
|
||||
<activeByDefault>false</activeByDefault>
|
||||
</activation>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-cuda-10.1</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</profile>
|
||||
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
</profile>
|
||||
|
||||
</profiles>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<configuration>
|
||||
<source>${maven.compiler.source}</source>
|
||||
<target>${maven.compiler.target}</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
|
@ -0,0 +1,288 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.eclipse.jetty.server.Server;
|
||||
import org.eclipse.jetty.servlet.ServletContextHandler;
|
||||
import org.nd4j.adapters.InputAdapter;
|
||||
import org.nd4j.adapters.OutputAdapter;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
import org.nd4j.remote.serving.ModelServingServlet;
|
||||
import org.nd4j.remote.serving.SameDiffServlet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class provides JSON-powered model serving functionality for SameDiff graphs.
|
||||
* Server url will be http://0.0.0.0:{port}>/v1/serving
|
||||
* Server only accepts POST requests
|
||||
*
|
||||
* @param <I> type of the input class, i.e. String
|
||||
* @param <O> type of the output class, i.e. Sentiment
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Slf4j
|
||||
public class SameDiffJsonModelServer<I, O> {
|
||||
|
||||
|
||||
protected SameDiff sdModel;
|
||||
protected final JsonSerializer<O> serializer;
|
||||
protected final JsonDeserializer<I> deserializer;
|
||||
protected final InferenceAdapter<I, O> inferenceAdapter;
|
||||
protected final int port;
|
||||
|
||||
// this servlet will be used to serve models
|
||||
protected ModelServingServlet<I, O> servingServlet;
|
||||
|
||||
// HTTP server instance
|
||||
protected Server server;
|
||||
|
||||
// for SameDiff only
|
||||
protected String[] orderedInputNodes;
|
||||
protected String[] orderedOutputNodes;
|
||||
|
||||
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port) {
|
||||
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
|
||||
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
this.serializer = serializer;
|
||||
this.deserializer = deserializer;
|
||||
this.port = port;
|
||||
}
|
||||
|
||||
//@Builder
|
||||
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
|
||||
this(inferenceAdapter, serializer, deserializer, port);
|
||||
this.sdModel = sdModel;
|
||||
this.orderedInputNodes = orderedInputNodes;
|
||||
this.orderedOutputNodes = orderedOutputNodes;
|
||||
|
||||
// TODO: both lists of nodes should be validated, to make sure nodes specified here exist in actual model
|
||||
if (orderedInputNodes != null) {
|
||||
// input nodes list might be null. strange but ok
|
||||
}
|
||||
|
||||
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "SameDiff serving requires at least 1 output node");
|
||||
}
|
||||
|
||||
protected void start(int port, @NonNull ModelServingServlet<I, O> servlet) throws Exception {
|
||||
val context = new ServletContextHandler(ServletContextHandler.SESSIONS);
|
||||
context.setContextPath("/");
|
||||
|
||||
server = new Server(port);
|
||||
server.setHandler(context);
|
||||
|
||||
val jerseyServlet = context.addServlet(org.glassfish.jersey.servlet.ServletContainer.class, "/*");
|
||||
jerseyServlet.setInitOrder(0);
|
||||
jerseyServlet.setServlet(servlet);
|
||||
|
||||
server.start();
|
||||
}
|
||||
|
||||
public void start() throws Exception {
|
||||
Preconditions.checkArgument(sdModel != null, "SameDiff model wasn't defined");
|
||||
|
||||
servingServlet = SameDiffServlet.<I, O>builder()
|
||||
.sdModel(sdModel)
|
||||
.serializer(serializer)
|
||||
.deserializer(deserializer)
|
||||
.inferenceAdapter(inferenceAdapter)
|
||||
.orderedInputNodes(orderedInputNodes)
|
||||
.orderedOutputNodes(orderedOutputNodes)
|
||||
.build();
|
||||
|
||||
start(port, servingServlet);
|
||||
}
|
||||
|
||||
public void join() throws InterruptedException {
|
||||
Preconditions.checkArgument(server != null, "Model server wasn't started yet");
|
||||
|
||||
server.join();
|
||||
}
|
||||
|
||||
public void stop() throws Exception {
|
||||
//Preconditions.checkArgument(server != null, "Model server wasn't started yet");
|
||||
|
||||
server.stop();
|
||||
}
|
||||
|
||||
|
||||
public static class Builder<I,O> {
|
||||
private SameDiff sameDiff;
|
||||
private String[] orderedInputNodes;
|
||||
private String[] orderedOutputNodes;
|
||||
private InferenceAdapter<I, O> inferenceAdapter;
|
||||
private JsonSerializer<O> serializer;
|
||||
private JsonDeserializer<I> deserializer;
|
||||
private int port;
|
||||
|
||||
private InputAdapter<I> inputAdapter;
|
||||
private OutputAdapter<O> outputAdapter;
|
||||
|
||||
public Builder() {}
|
||||
|
||||
public Builder<I,O> sdModel(@NonNull SameDiff sameDiff) {
|
||||
this.sameDiff = sameDiff;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type
|
||||
* @param inferenceAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inferenceAdapter(InferenceAdapter<I,O> inferenceAdapter) {
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify InputAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, and will require OutputAdapter<O> defined
|
||||
* @param inputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
|
||||
this.inputAdapter = inputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows you to specify OutputAdapter to be used for inference
|
||||
*
|
||||
* PLEASE NOTE: This method is optional, and will require InputAdapter<I> defined
|
||||
* @param outputAdapter
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
|
||||
this.outputAdapter = outputAdapter;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines JsonSerializer instance to be used to convert object of output type into JSON format, so it could be sent over the wire
|
||||
*
|
||||
* @param serializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
|
||||
this.serializer = serializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines JsonDeserializer instance to be used to convert JSON passed through HTTP into actual object of input type, that will be fed into SameDiff model
|
||||
*
|
||||
* @param deserializer
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
|
||||
this.deserializer = deserializer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines the order of placeholders to be filled with INDArrays provided by Deserializer
|
||||
*
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(String... args) {
|
||||
orderedInputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines the order of placeholders to be filled with INDArrays provided by Deserializer
|
||||
*
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedInputNodes(@NonNull List<String> args) {
|
||||
orderedInputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(String... args) {
|
||||
Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input
|
||||
* @param args
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> orderedOutputNodes(@NonNull List<String> args) {
|
||||
Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
|
||||
orderedOutputNodes = args.toArray(new String[args.size()]);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method allows to configure HTTP port used for serving
|
||||
*
|
||||
* PLEASE NOTE: port must be free and be in range regular TCP/IP ports range
|
||||
* @param port
|
||||
* @return
|
||||
*/
|
||||
public Builder<I,O> port(int port) {
|
||||
this.port = port;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method builds SameDiffJsonModelServer instance
|
||||
* @return
|
||||
*/
|
||||
public SameDiffJsonModelServer<I, O> build() {
|
||||
if (inferenceAdapter == null) {
|
||||
if (inputAdapter != null && outputAdapter != null) {
|
||||
inferenceAdapter = new InferenceAdapter<I, O>() {
|
||||
@Override
|
||||
public MultiDataSet apply(I input) {
|
||||
return inputAdapter.apply(input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public O apply(INDArray... outputs) {
|
||||
return outputAdapter.apply(outputs);
|
||||
}
|
||||
};
|
||||
} else
|
||||
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
|
||||
}
|
||||
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.serving;
|
||||
|
||||
import javax.servlet.Servlet;
|
||||
|
||||
/**
|
||||
* This interface describes Servlet interface extension, suited for ND4J/DL4J model serving
|
||||
* @param <I>
|
||||
* @param <O>
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public interface ModelServingServlet<I, O> extends Servlet {
|
||||
//
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.serving;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
|
||||
import javax.servlet.*;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.ws.rs.HttpMethod;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.LinkedHashMap;
|
||||
|
||||
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
|
||||
|
||||
/**
|
||||
* This servlet provides SameDiff model serving capabilities
|
||||
*
|
||||
* @param <I>
|
||||
* @param <O>
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Slf4j
|
||||
@Builder
|
||||
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
|
||||
|
||||
protected SameDiff sdModel;
|
||||
protected JsonSerializer<O> serializer;
|
||||
protected JsonDeserializer<I> deserializer;
|
||||
protected InferenceAdapter<I, O> inferenceAdapter;
|
||||
|
||||
protected String[] orderedInputNodes;
|
||||
protected String[] orderedOutputNodes;
|
||||
|
||||
protected final static String SERVING_ENDPOINT = "/v1/serving";
|
||||
protected final static String LISTING_ENDPOINT = "/v1";
|
||||
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable
|
||||
|
||||
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer){
|
||||
this.serializer = serializer;
|
||||
this.deserializer = deserializer;
|
||||
this.inferenceAdapter = inferenceAdapter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void init(ServletConfig servletConfig) throws ServletException {
|
||||
//
|
||||
}
|
||||
|
||||
@Override
|
||||
public ServletConfig getServletConfig() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException {
|
||||
// we'll parse request here, and do model serving
|
||||
val httpRequest = (HttpServletRequest) servletRequest;
|
||||
val httpResponse = (HttpServletResponse) servletResponse;
|
||||
|
||||
if (httpRequest.getMethod().equals(HttpMethod.GET)) {
|
||||
doGet(httpRequest, httpResponse);
|
||||
}
|
||||
else if (httpRequest.getMethod().equals(HttpMethod.POST)) {
|
||||
doPost(httpRequest, httpResponse);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected void sendError(String uri, HttpServletResponse response) throws IOException {
|
||||
val msg = "Requested endpoint [" + uri + "] not found";
|
||||
response.setStatus(404, msg);
|
||||
response.sendError(404, msg);
|
||||
}
|
||||
|
||||
protected void sendBadContentType(String actualContentType, HttpServletResponse response) throws IOException {
|
||||
val msg = "Content type [" + actualContentType + "] not supported";
|
||||
response.setStatus(415, msg);
|
||||
response.sendError(415, msg);
|
||||
}
|
||||
|
||||
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
|
||||
throws IOException{
|
||||
val contentType = request.getContentType();
|
||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
||||
sendBadContentType(contentType, response);
|
||||
int contentLength = request.getContentLength();
|
||||
if (contentLength > PAYLOAD_SIZE_LIMIT) {
|
||||
response.sendError(500, "Payload size limit violated!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||
val processor = new ServingProcessor();
|
||||
String processorReturned = "";
|
||||
String path = request.getPathInfo();
|
||||
if (path.equals(LISTING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
|
||||
sendBadContentType(contentType, response);
|
||||
}
|
||||
processorReturned = processor.listEndpoints();
|
||||
}
|
||||
else {
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
try {
|
||||
val out = response.getWriter();
|
||||
out.write(processorReturned);
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
|
||||
val processor = new ServingProcessor();
|
||||
String processorReturned = "";
|
||||
String path = request.getPathInfo();
|
||||
if (path.equals(SERVING_ENDPOINT)) {
|
||||
val contentType = request.getContentType();
|
||||
/*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON),
|
||||
"Content type is " + contentType);*/
|
||||
if (validateRequest(request,response)) {
|
||||
val stream = request.getInputStream();
|
||||
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
|
||||
char[] charBuffer = new char[128];
|
||||
int bytesRead = -1;
|
||||
val buffer = new StringBuilder();
|
||||
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
|
||||
buffer.append(charBuffer, 0, bytesRead);
|
||||
}
|
||||
val requestString = buffer.toString();
|
||||
|
||||
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
|
||||
val map = new LinkedHashMap<String, INDArray>();
|
||||
|
||||
// optionally define placeholders with names provided in server constructor
|
||||
if (orderedInputNodes != null && orderedInputNodes.length > 0) {
|
||||
int cnt = 0;
|
||||
for (val n : orderedInputNodes)
|
||||
map.put(n, mds.getFeatures(cnt++));
|
||||
}
|
||||
|
||||
val output = sdModel.exec(map, orderedOutputNodes);
|
||||
val arrays = new INDArray[output.size()];
|
||||
|
||||
// now we need to get ordered output arrays, as specified in server constructor
|
||||
int cnt = 0;
|
||||
for (val n : orderedOutputNodes)
|
||||
arrays[cnt++] = output.get(n);
|
||||
|
||||
// process result
|
||||
val result = inferenceAdapter.apply(arrays);
|
||||
processorReturned = serializer.serialize(result);
|
||||
}
|
||||
} else {
|
||||
// we return error otherwise
|
||||
sendError(request.getRequestURI(), response);
|
||||
}
|
||||
try {
|
||||
val out = response.getWriter();
|
||||
out.write(processorReturned);
|
||||
} catch (IOException e) {
|
||||
log.error(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getServletInfo() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
//
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
package org.nd4j.remote.serving;
|
||||
|
||||
public class ServingProcessor {
|
||||
|
||||
public String listEndpoints() {
|
||||
String retVal = "/v1/ \n/v1/serving/";
|
||||
return retVal;
|
||||
}
|
||||
|
||||
public String processModel(String body) {
|
||||
String response = null; //"Not implemented";
|
||||
return response;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,271 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.remote.clients.JsonRemoteInference;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.helpers.House;
|
||||
import org.nd4j.remote.helpers.HouseToPredictedPriceAdapter;
|
||||
import org.nd4j.remote.helpers.PredictedPrice;
|
||||
import org.nd4j.remote.clients.serde.impl.FloatArraySerde;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Future;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
public class SameDiffJsonModelServerTest {
|
||||
|
||||
@Test
|
||||
public void basicServingTest_1() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
price = client.predict(house);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
|
||||
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_1() {
|
||||
String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
|
||||
val deserializer = new House.HouseDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(2, result.getDistrict());
|
||||
assertEquals(100, result.getArea());
|
||||
assertEquals(2, result.getBathrooms());
|
||||
assertEquals(3, result.getBedrooms());
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_2() {
|
||||
String request = "{\"price\":1}";
|
||||
val deserializer = new PredictedPrice.PredictedPriceDeserializer();
|
||||
val result = deserializer.deserialize(request);
|
||||
assertEquals(1.0, result.getPrice(), 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeserialization_3() {
|
||||
float[] data = {0.0f, 0.1f, 0.2f};
|
||||
val serialized = new FloatArraySerde().serialize(data);
|
||||
val deserialized = new FloatArraySerde().deserialize(serialized);
|
||||
assertArrayEquals(data, deserialized, 1e-5f);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void negativeServingTest_1() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(null)
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void negativeServingTest_2() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
|
||||
}
|
||||
|
||||
@Test(expected = IOException.class)
|
||||
public void negativeServingTest_3() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
// warmup
|
||||
PredictedPrice price = client.predict(house);
|
||||
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void asyncServingTest() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
|
||||
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void negativeAsyncTest() throws Exception {
|
||||
val sd = SameDiff.create();
|
||||
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
|
||||
val result = sdVariable.add(1.0);
|
||||
val total = result.mean("total", Integer.MAX_VALUE);
|
||||
|
||||
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
|
||||
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
|
||||
.inputDeserializer(new House.HouseDeserializer())
|
||||
.inferenceAdapter(new HouseToPredictedPriceAdapter())
|
||||
.orderedInputNodes(new String[]{"input"})
|
||||
.orderedOutputNodes(new String[]{"total"})
|
||||
.sdModel(sd)
|
||||
.port(18080)
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
|
||||
// Fake deserializer to test failure
|
||||
val client = JsonRemoteInference.<House, PredictedPrice>builder()
|
||||
.inputSerializer(new House.HouseSerializer())
|
||||
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
|
||||
@Override
|
||||
public PredictedPrice deserialize(String json) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.endpointAddress("http://localhost:18080/v1/serving")
|
||||
.build();
|
||||
|
||||
int district = 2;
|
||||
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
|
||||
|
||||
val timeStart = System.currentTimeMillis();
|
||||
try {
|
||||
Future<PredictedPrice> price = client.predictAsync(house);
|
||||
assertNotNull(price);
|
||||
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
|
||||
val timeStop = System.currentTimeMillis();
|
||||
|
||||
log.info("Time spent: {} ms", timeStop - timeStart);
|
||||
} catch (ExecutionException e) {
|
||||
assertTrue(e.getMessage().contains("Deserialization failed"));
|
||||
}
|
||||
|
||||
server.stop();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package org.nd4j.remote;
|
||||
|
||||
import lombok.val;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.impl.client.HttpClientBuilder;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class SameDiffServletTest {
|
||||
|
||||
private SameDiffJsonModelServer server;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
server = new SameDiffJsonModelServer.Builder<String, String>()
|
||||
.sdModel(SameDiff.create())
|
||||
.port(8080)
|
||||
.inferenceAdapter(new InferenceAdapter<String, String>() {
|
||||
@Override
|
||||
public MultiDataSet apply(String input) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String apply(INDArray... nnOutput) {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.outputSerializer(new JsonSerializer<String>() {
|
||||
@Override
|
||||
public String serialize(String o) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.inputDeserializer(new JsonDeserializer<String>() {
|
||||
@Override
|
||||
public String deserialize(String json) {
|
||||
return "";
|
||||
}
|
||||
})
|
||||
.orderedOutputNodes(new String[]{"output"})
|
||||
.build();
|
||||
|
||||
server.start();
|
||||
//server.join();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
server.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getEndpoints() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "application/json");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(200, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypeGet() throws IOException {
|
||||
val request = new HttpGet( "http://localhost:8080/v1" );
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testContentTypePost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "text/plain");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(415, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void postForServing() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(500, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundPost() throws Exception {
|
||||
val request = new HttpPost("http://localhost:8080/v1/serving/some");
|
||||
request.setHeader("Content-type", "application/json");
|
||||
val response = HttpClientBuilder.create().build().execute( request );
|
||||
assertEquals(404, response.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotFoundGet() throws Exception {
|
||||
val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
|
||||
requestGet.setHeader("Content-type", "application/json");
|
||||
|
||||
val responseGet = HttpClientBuilder.create().build().execute( requestGet );
|
||||
assertEquals(404, responseGet.getStatusLine().getStatusCode());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.*;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class House {
|
||||
private int district;
|
||||
private int bedrooms;
|
||||
private int bathrooms;
|
||||
private int area;
|
||||
|
||||
|
||||
public static class HouseSerializer implements JsonSerializer<House> {
|
||||
@Override
|
||||
public String serialize(@NonNull House o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class HouseDeserializer implements JsonDeserializer<House> {
|
||||
@Override
|
||||
public House deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, House.class);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.helpers;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.adapters.InferenceAdapter;
|
||||
|
||||
@Slf4j
|
||||
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
|
||||
|
||||
@Override
|
||||
public MultiDataSet apply(@NonNull House input) {
|
||||
// we just create vector array with shape[4] and assign it's value to the district value
|
||||
return new MultiDataSet(Nd4j.create(DataType.FLOAT, 4).assign(input.getDistrict()), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PredictedPrice apply(INDArray... nnOutput) {
|
||||
return new PredictedPrice(nnOutput[0].getFloat(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.remote.helpers;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.remote.clients.serde.JsonDeserializer;
|
||||
import org.nd4j.remote.clients.serde.JsonSerializer;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PredictedPrice {
|
||||
private float price;
|
||||
|
||||
public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
|
||||
@Override
|
||||
public String serialize(@NonNull PredictedPrice o) {
|
||||
return new Gson().toJson(o);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
|
||||
@Override
|
||||
public PredictedPrice deserialize(@NonNull String json) {
|
||||
return new Gson().fromJson(json, PredictedPrice.class);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,90 @@
|
|||
package org.nd4j.remote.serde;
|
||||
|
||||
import lombok.val;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.remote.clients.serde.impl.*;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class BasicSerdeTests {
|
||||
private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde();
|
||||
private final static FloatArraySerde floatArraySerde = new FloatArraySerde();
|
||||
private final static StringSerde stringSerde = new StringSerde();
|
||||
private final static IntegerSerde integerSerde = new IntegerSerde();
|
||||
private final static FloatSerde floatSerde = new FloatSerde();
|
||||
private final static DoubleSerde doubleSerde = new DoubleSerde();
|
||||
private final static BooleanSerde booleanSerde = new BooleanSerde();
|
||||
|
||||
@Test
|
||||
public void testStringSerde_1() {
|
||||
val jvmString = "String with { strange } elements";
|
||||
|
||||
val serialized = stringSerde.serialize(jvmString);
|
||||
val deserialized = stringSerde.deserialize(serialized);
|
||||
|
||||
assertEquals(jvmString, deserialized);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFloatArraySerDe_1() {
|
||||
val jvmArray = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
|
||||
|
||||
val serialized = floatArraySerde.serialize(jvmArray);
|
||||
val deserialized = floatArraySerde.deserialize(serialized);
|
||||
|
||||
assertArrayEquals(jvmArray, deserialized, 1e-5f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleArraySerDe_1() {
|
||||
val jvmArray = new double[] {1.0, 2.0, 3.0, 4.0, 5.0};
|
||||
|
||||
val serialized = doubleArraySerde.serialize(jvmArray);
|
||||
val deserialized = doubleArraySerde.deserialize(serialized);
|
||||
|
||||
assertArrayEquals(jvmArray, deserialized, 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFloatSerde_1() {
|
||||
val f = 119.f;
|
||||
|
||||
val serialized = floatSerde.serialize(f);
|
||||
val deserialized = floatSerde.deserialize(serialized);
|
||||
|
||||
assertEquals(f, deserialized, 1e-5f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDoubleSerde_1() {
|
||||
val d = 119.;
|
||||
|
||||
val serialized = doubleSerde.serialize(d);
|
||||
val deserialized = doubleSerde.deserialize(serialized);
|
||||
|
||||
assertEquals(d, deserialized, 1e-5);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIntegerSerde_1() {
|
||||
val f = 119;
|
||||
|
||||
val serialized = integerSerde.serialize(f);
|
||||
val deserialized = integerSerde.deserialize(serialized);
|
||||
|
||||
|
||||
assertEquals(f, deserialized.intValue());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBooleanSerde_1() {
|
||||
val f = true;
|
||||
|
||||
val serialized = booleanSerde.serialize(f);
|
||||
val deserialized = booleanSerde.deserialize(serialized);
|
||||
|
||||
|
||||
assertEquals(f, deserialized);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
||||
~
|
||||
~ This program and the accompanying materials are made available under the
|
||||
~ terms of the Apache License, Version 2.0 which is available at
|
||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ License for the specific language governing permissions and limitations
|
||||
~ under the License.
|
||||
~
|
||||
~ SPDX-License-Identifier: Apache-2.0
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
||||
|
||||
<configuration>
|
||||
|
||||
|
||||
|
||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
||||
<file>logs/application.log</file>
|
||||
<encoder>
|
||||
<pattern>%date - [%level] - from %logger in %thread
|
||||
%n%message%n%xException%n</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder>
|
||||
<pattern> %logger{15} - %message%n%xException{5}
|
||||
</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<logger name="org.eclipse.jetty" level="WARN" />
|
||||
<logger name="org.apache.catalina.core" level="WARN" />
|
||||
<logger name="org.springframework" level="WARN" />
|
||||
<logger name="org.nd4j" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="INFO" />
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
|
||||
</configuration>
|
|
@ -0,0 +1,35 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<packaging>pom</packaging>
|
||||
|
||||
<modules>
|
||||
<module>nd4j-json-client</module>
|
||||
<module>nd4j-grpc-client</module>
|
||||
<module>nd4j-json-server</module>
|
||||
</modules>
|
||||
|
||||
<parent>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
|
||||
<artifactId>nd4j-remote</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
<name>nd4j-remote</name>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.7</maven.compiler.source>
|
||||
<maven.compiler.target>1.7</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -33,7 +33,6 @@
|
|||
<module>nd4j-camel-routes</module>
|
||||
<module>nd4j-gson</module>
|
||||
<module>nd4j-arrow</module>
|
||||
<module>nd4j-grpc</module>
|
||||
</modules>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -62,6 +62,7 @@
|
|||
<module>nd4j-parameter-server-parent</module>
|
||||
<module>nd4j-uberjar</module>
|
||||
<module>nd4j-tensorflow</module>
|
||||
<module>nd4j-remote</module>
|
||||
</modules>
|
||||
|
||||
<dependencyManagement>
|
||||
|
|
Loading…
Reference in New Issue