[WIP] Remote inference (#96)

* fix pad javadoc and @see links. (#72)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* [WIP] More fixes (#73)

* special tests for ConstantTadHelper/ConstantShapeHelper

Signed-off-by: raver119 <raver119@gmail.com>

* release methods for data buffers

Signed-off-by: raver119 <raver119@gmail.com>

* delete temporary buffer Java side

Signed-off-by: raver119 <raver119@gmail.com>

* delete temporary buffer Java side

Signed-off-by: raver119 <raver119@gmail.com>

* delete temporary TadPack C++/Java side (#74)

Signed-off-by: raver119 <raver119@gmail.com>

* Zoo model TF import test updates (#75)

* argLine fix, update compression_gru comment

* updated comment for xception

* undid but commented argLine change

* updated xlnet comment

* copyright headers

* - new NDArray methods like()/ulike() (#77)

- fix for depthwise_conv2d_bp + special test

Signed-off-by: raver119 <raver119@gmail.com>

* upsampling2d fix CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* DL4J trace logging (#79)

* MLN/CG trace logging for debugging

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tiny tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* strided_slice_bp shape fn leak fix

Signed-off-by: raver119 <raver119@gmail.com>

* SameDiff fixes and naming (#78)

* remove SDVariable inplace methods

* import methods

* npe fix in OpVal

* removed SameDiff inplace ops from tests

* Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything

* quick fixes

* javadoc

* SDVariable eval with placeholders

* use regex match

* better matching

* fix javadoc. (#76)

* fix javadoc.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* replace most @see with @link s.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* 4 additional tests

Signed-off-by: raver119 <raver119@gmail.com>

* Various DL4J/ND4J fixes (#81)

* #7954 Force refresh of UI when switching tabs on overview page

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8017 Concurrent modification exception (synchronize) fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8033 Don't initialize updater in middle of writing memory crash dump

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8208 Fix shape checks for ND4J int[] creator methods

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6385 #7992 Keras import naming fixes + cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8016 Upsampling3D - add NDHWC format support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Refactor NativeOps.h to export C functions

* Actually export functions from NativeOps.h

* Adapt the Java wrappers in ND4J generated with JavaCPP

* Create C wrappers for some of the C++ classes currently used by ND4J

* remove duplicate code in createBufferDetached. (#83)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Keras model import - updater lr fix (#84)

* Keras model import - updater lr fix

Signed-off-by: eraly <susan.eraly@gmail.com>

* Keras model import - updater lr fix, cleanup

Signed-off-by: eraly <susan.eraly@gmail.com>

* Fix functions of OpaqueVariablesSet

* SameDiff Convolution Config validation, better output methods (#82)

* Conv Config validation & tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* stackOutputs utility method

Signed-off-by: Ryan Nett <rnett@skymind.io>

* use constructor for validation, support negative kernel sizes (infered from weights)

Signed-off-by: Ryan Nett <rnett@skymind.io>

* better output methods

Signed-off-by: Ryan Nett <rnett@skymind.io>

* move output to be with fit and evaluate

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* refactor duplicate code from pad methods. (#86)

* refactor duplicate code from pad methods.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* replace switch with if.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Various ND4J/DL4J fixes and improvements (#87)

* Reshape and reallocate - small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Reshape and reallocate - small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6488 ElementWiseVertex broadcast support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Constructors and broadcast supported it Transforms.max/min

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8054 ElementWiseVertex now supports broadcast inputs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8057 Nd4j.create overload dtype fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7551 ND4J Shape validation fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Numpy boolean import (#91)

* numpy bool type

Signed-off-by: raver119 <raver119@gmail.com>

* numpy bool java side

Signed-off-by: raver119 <raver119@gmail.com>

* remove create method with unused parameter. (#89)

* remove create method with unused parameter.

* removed more unused methods.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* removing more unused code.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* last removal of unused code.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* remove createSparse methods. (#92)

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* Various ND4J/DL4J fixes (#90)

* Deprecate Old*Op instances

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8063 #8054 Broadcast exceptions + cleanup inplace ops

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove bad test condition

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7993 Fix shape function issue in crop_and_resize op

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J SameDiff lambda layer fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8029 Fix for pnorm backprop math

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8038 Fix Op profiler NaN/Inf triggering + add tests (#93)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* createUninitializedDetached refactoring. (#94)

* wip

* update interface, add null implementations.

* Breaking one test in a weird way.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* createUninitializedDetached refactored.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* cuda build fix for issues introduced by recent refactoring

Signed-off-by: raver119 <raver119@gmail.com>

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* deps tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* initial prototype

Signed-off-by: raver119 <raver119@gmail.com>

* modules reorganized

Signed-off-by: raver119 <raver119@gmail.com>

* gprc module moved to nd4j-remote as well

Signed-off-by: raver119 <raver119@gmail.com>

* gprc module moved to nd4j-remote as well

Signed-off-by: raver119 <raver119@gmail.com>

* serving prototype

Signed-off-by: raver119 <raver119@gmail.com>

* serving prototype

Signed-off-by: raver119 <raver119@gmail.com>

* serving prototype

Signed-off-by: raver119 <raver119@gmail.com>

* serving prototype

Signed-off-by: raver119 <raver119@gmail.com>

* [WIP] More of CUDA (#95)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* Implementation of hashcode cuda helper. Working edition.

* Fixed parallel test input arangements.

* Fixed tests for hashcode op.

* Fixed shape calculation for image:crop_and_resize op and test.

* NativeOps tests. Initial test suite.

* Added tests for indexReduce methods.

* Added test on execBroadcast with NDArray as dimensions.

* Added test on execBroadcastBool with NDArray as dimensions.

* Added tests on execPairwiseTransform and execPairwiseTransofrmBool.

* Added tests for execReduce with scalar results.

* Added reduce tests for non-empty dims array.

* Added tests for reduce3.

* Added tests for execScalar.

* Added tests for execSummaryStats.

* - provide cpu/cuda code for batch_to_space
- testing it

Signed-off-by: Yurii <yurii@skymind.io>

* - remove old test for batch_to_space (had wrong format and numbers were not checked)

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed complilation errors with test.

* Added test for execTransformFloat.

* Added test for execTransformSame.

* Added test for execTransformBool.

* Added test for execTransformStrict.

* Added tests for execScalar/execScalarBool with TADs.

* Added test for flatten.

* - provide cpu/cuda code for space_to_Batch operaion

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for concat.

* comment unnecessary stuff in s_t_b

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for specialConcat.

* Added tests for memcpy/set routines.

* Fixed pullRow cuda test.

* Added pullRow test.

* Added average test.

* - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...)

Signed-off-by: Yurii <yurii@skymind.io>

* - debugging and fixing cuda tests in JavaInteropTests file

Signed-off-by: Yurii <yurii@skymind.io>

* - correct some tests

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for shuffle.

* Fixed ops declarations.

* Restored omp and added shuffle test.

* Added convertTypes test.

* Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps.

* Added sort tests.

* Added tests for execCustomOp.

* - further debuging and fixing tests terminated with crash

Signed-off-by: Yurii <yurii@skymind.io>

* Added tests for calculateOutputShapes.

* Addded Benchmarks test.

* Commented benchmark tests.

* change assertion

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for apply_sgd op. Added cpu helper for that op.

* Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps.

* Added test for assign broadcastable.

* Added tests for assign_bp op.

* Added tests for axpy op.

* - assign/execScalar/execTransformAny signature change
- minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed axpy op.

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* - fix tests for nativeOps::concat

Signed-off-by: Yurii <yurii@skymind.io>

* sequential transform/scalar

Signed-off-by: raver119 <raver119@gmail.com>

* allow nested parallelism

Signed-off-by: raver119 <raver119@gmail.com>

* assign_bp leak fix

Signed-off-by: raver119 <raver119@gmail.com>

* block setRNG fix

Signed-off-by: raver119 <raver119@gmail.com>

* enable parallelism by default

Signed-off-by: raver119 <raver119@gmail.com>

* enable nested parallelism by default

Signed-off-by: raver119 <raver119@gmail.com>

* Added cuda implementation for row_count helper.

* Added implementation for tnse gains op helper.

* - take into account possible situations when input arrays are empty in reduce_ cuda stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces.

* Added kernel for tsne/symmetrized op heleper.

* Implementation of tsne/symmetrized op cuda helper. Working edition.

* Eliminated waste printfs.

* Added test for broadcastgradientargs op.

* host-only fallback for empty reduce float

Signed-off-by: raver119 <raver119@gmail.com>

* - some tests fixes

Signed-off-by: Yurii <yurii@skymind.io>

* - correct the rest of reduce_ stuff

Signed-off-by: Yurii <yurii@skymind.io>

* - further correction of reduce_ stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for Cbow op. Also added cuda implementation for cbow helpers.

* - improve code of stack operation for scalar case

Signed-off-by: Yurii <yurii@skymind.io>

* - provide cuda kernel for gatherND operation

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of cbow helpers with cuda kernels.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - further correction of cuda stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Implementatation of cbow op helper with cuda kernels. Working edition.

* Skip random testing for cudablas case.

* lstmBlockCell context fix

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for ELU and ELU_BP ops.

* Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops.

* Added tests for neq_scalar.

* Added test for noop.

* - further work on clipbynorm_bp

Signed-off-by: Yurii <yurii@skymind.io>

* - get rid of concat op call, use instead direct concat helper call

Signed-off-by: Yurii <yurii@skymind.io>

* lstmBlockCell context fix

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for lrelu and lrelu_bp.

* Added tests for selu and selu_bp.

* Fixed lrelu derivative helpers.

* - some corrections in lstm

Signed-off-by: Yurii <yurii@skymind.io>

* operator * result shape fix

Signed-off-by: raver119 <raver119@gmail.com>

* - correct typo in lstmCell

Signed-off-by: Yurii <yurii@skymind.io>

* few tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA inverse broadcast bool fix

Signed-off-by: raver119 <raver119@gmail.com>

* disable MMAP test for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* BooleanOp syncToDevice

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* additional data types for im2col/col2im

Signed-off-by: raver119 <raver119@gmail.com>

* Added test for firas_sparse op.

* one more RandomBuffer test excluded

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for flatten op.

* Added test for Floor op.

* bunch of tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* mmulDot tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Implemented floordiv_bp op and tests.

* Fixed scalar case with cuda implementation for bds.

* - work on cuda kernel for clip_by_norm backprop op is completed

Signed-off-by: Yurii <yurii@skymind.io>

* Eliminate cbow crach.

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminated abortion with batched nlp test.

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed shared flag initializing.

* disabled bunch of cpu workspaces tests

Signed-off-by: raver119 <raver119@gmail.com>

* scalar operators fix: missing registerSpecialUse call

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed logdet for cuda and tests.

* - correct clipBynorm_bp

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed crop_and_resize shape datatype.

* - correct some mmul tests

Signed-off-by: Yurii <yurii@skymind.io>

* build fix

Signed-off-by: raver119 <raver119@gmail.com>

* exclude two methods for JNI

Signed-off-by: raver119 <raver119@gmail.com>

* exclude two methods for JNI

Signed-off-by: raver119 <raver119@gmail.com>

* exclude two methods for JNI (#97)

Signed-off-by: raver119 <raver119@gmail.com>

* temporary stack fix

Signed-off-by: raver119 <raver119@gmail.com>

* downgrade jetty to latest stable version

Signed-off-by: raver119 <raver119@gmail.com>

* test and profiles

Signed-off-by: raver119 <raver119@gmail.com>

* Servlet skeleton

* one test case

Signed-off-by: raver119 <raver119@gmail.com>

* one test case

Signed-off-by: raver119 <raver119@gmail.com>

* compilation fix

Signed-off-by: raver119 <raver119@gmail.com>

* draft improvements

Signed-off-by: raver119 <raver119@gmail.com>

* draft improvements

Signed-off-by: raver119 <raver119@gmail.com>

* proof of concept works

Signed-off-by: raver119 <raver119@gmail.com>

* proof of concept works

Signed-off-by: raver119 <raver119@gmail.com>

* Servlet

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* logging + simple timing

Signed-off-by: raver119 <raver119@gmail.com>

* Content type fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Profile required

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Servlet tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Post test

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests added:

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Minor tweaks

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Constants used

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Check content type

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Some tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Errors checking

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Constraints and tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Minor tweaks

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Dl4j servlet skeleton

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Moving class to dl4j

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Builder extended

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* initial dl4j commit

Signed-off-by: raver119 <raver119@gmail.com>

* unirest version change

Signed-off-by: raver119 <raver119@gmail.com>

* temp fallback

Signed-off-by: raver119 <raver119@gmail.com>

* Reverted unirest version

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Reverted unirest version

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* revert back unirest version change

Signed-off-by: raver119 <raver119@gmail.com>

* revert unirest change

Signed-off-by: raver119 <raver119@gmail.com>

* some additional checks in builder

Signed-off-by: raver119 <raver119@gmail.com>

* few more fields

Signed-off-by: raver119 <raver119@gmail.com>

* Test added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* lombok

Signed-off-by: raver119 <raver119@gmail.com>

* Tests added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* deps

Signed-off-by: raver119 <raver119@gmail.com>

* profiles re-introduced

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Model servlet

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* builders

Signed-off-by: raver119 <raver119@gmail.com>

* builders

Signed-off-by: raver119 <raver119@gmail.com>

* Servlet skeleton

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Servlet tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* builders

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of old class

Signed-off-by: raver119 <raver119@gmail.com>

* use PI for inference

Signed-off-by: raver119 <raver119@gmail.com>

* superbuilder

Signed-off-by: raver119 <raver119@gmail.com>

* get back builder

Signed-off-by: raver119 <raver119@gmail.com>

* Servlet builder

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* PI setup

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of superbuilder

Signed-off-by: raver119 <raver119@gmail.com>

* SameDiffServlet inheritance constructor

Signed-off-by: raver119 <raver119@gmail.com>

* dl4jservlet attached to samediffservlet

Signed-off-by: raver119 <raver119@gmail.com>

* builder types fix

Signed-off-by: raver119 <raver119@gmail.com>

* dummy model

Signed-off-by: raver119 <raver119@gmail.com>

* single out

Signed-off-by: raver119 <raver119@gmail.com>

* loss

Signed-off-by: raver119 <raver119@gmail.com>

* Tests added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* missed builder type

Signed-off-by: raver119 <raver119@gmail.com>

* working serving example

Signed-off-by: raver119 <raver119@gmail.com>

* sd model fix

Signed-off-by: raver119 <raver119@gmail.com>

* fix unirest version

Signed-off-by: raver119 <raver119@gmail.com>

* More tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tests added:

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Minor tests fixes

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Build fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Test added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Ser/deser added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* one more unirest fix

Signed-off-by: raver119 <raver119@gmail.com>

* Custom serializers

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Tests disabled

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* revert back unirest version change

Signed-off-by: raver119 <raver119@gmail.com>

* update

Signed-off-by: raver119 <raver119@gmail.com>

* some default fields values

Signed-off-by: raver119 <raver119@gmail.com>

* some comments/javadoc

Signed-off-by: raver119 <raver119@gmail.com>

* - move serde impls to client module
- get rid of INDArray serde for now

Signed-off-by: raver119 <raver119@gmail.com>

* jackson-based serde for float[], double[] and String

Signed-off-by: raver119 <raver119@gmail.com>

* more of basic ser/de + tests

Signed-off-by: raver119 <raver119@gmail.com>

* minor api changes

Signed-off-by: raver119 <raver119@gmail.com>

* change imports/signatures

Signed-off-by: raver119 <raver119@gmail.com>

* Optional parralel inference

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Insert pause between tests as workaround for unavailable port issue

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* few unused imports removed

Signed-off-by: raver119 <raver119@gmail.com>

* Models usage

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Models usage

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* - InputAdapter + OutputAdapter = InferenceAdapter
- JsonModelServer now allows separate configuration of InputAdapter and OutputAdapter

Signed-off-by: raver119 <raver119@gmail.com>

* unused import

Signed-off-by: raver119 <raver119@gmail.com>

* input adapter..

Signed-off-by: raver119 <raver119@gmail.com>

* minor signature change

Signed-off-by: raver119 <raver119@gmail.com>

* few more signatures updated

Signed-off-by: raver119 <raver119@gmail.com>

* input/output adapter

Signed-off-by: raver119 <raver119@gmail.com>

* Tests added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* javadocs added

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* Test fixed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* minor polishing

Signed-off-by: raver119 <raver119@gmail.com>

* more of javadoc

Signed-off-by: raver119 <raver119@gmail.com>

* signature change

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-14 12:11:09 +03:00 committed by GitHub
parent b10ab239c0
commit ec847e034b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 3974 additions and 40 deletions

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform.client; package org.datavec.spark.transform.client;
import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException; import com.mashape.unirest.http.exceptions.UnirestException;

View File

@ -51,11 +51,13 @@
<artifactId>datavec-spark-inference-model</artifactId> <artifactId>datavec-spark-inference-model</artifactId>
<version>${datavec.version}</version> <version>${datavec.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-spark_2.11</artifactId> <artifactId>datavec-spark_2.11</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId> <artifactId>datavec-data-image</artifactId>
@ -67,61 +69,73 @@
<artifactId>akka-cluster_2.11</artifactId> <artifactId>akka-cluster_2.11</artifactId>
<version>${akka.version}</version> <version>${akka.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>joda-time</groupId> <groupId>joda-time</groupId>
<artifactId>joda-time</artifactId> <artifactId>joda-time</artifactId>
<version>${jodatime.version}</version> <version>${jodatime.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId> <artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version> <version>${commons-lang3.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.hibernate</groupId> <groupId>org.hibernate</groupId>
<artifactId>hibernate-validator</artifactId> <artifactId>hibernate-validator</artifactId>
<version>${hibernate.version}</version> <version>${hibernate.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId> <artifactId>scala-library</artifactId>
<version>${scala.version}</version> <version>${scala.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId> <artifactId>scala-reflect</artifactId>
<version>${scala.version}</version> <version>${scala.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.yaml</groupId> <groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId> <artifactId>snakeyaml</artifactId>
<version>${snakeyaml.version}</version> <version>${snakeyaml.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId> <artifactId>jackson-core</artifactId>
<version>${jackson.version}</version> <version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId> <artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version> <version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.core</groupId> <groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId> <artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version> <version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.datatype</groupId> <groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId> <artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version> <version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.fasterxml.jackson.datatype</groupId> <groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId> <artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version> <version>${jackson.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId> <artifactId>play-java_2.11</artifactId>
@ -137,39 +151,44 @@
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
<dependency> <dependency>
<groupId>net.jodah</groupId> <groupId>net.jodah</groupId>
<artifactId>typetools</artifactId> <artifactId>typetools</artifactId>
<version>${jodah.typetools.version}</version> <version>${jodah.typetools.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId> <artifactId>play-json_2.11</artifactId>
<version>${play.version}</version> <version>${play.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId> <artifactId>play-server_2.11</artifactId>
<version>${play.version}</version> <version>${play.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId> <artifactId>play_2.11</artifactId>
<version>${play.version}</version> <version>${play.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId> <artifactId>play-netty-server_2.11</artifactId>
<version>${play.version}</version> <version>${play.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mashape.unirest</groupId> <groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId> <artifactId>unirest-java</artifactId>
<version>${unirest.version}</version> <version>${unirest.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.beust</groupId> <groupId>com.beust</groupId>
<artifactId>jcommander</artifactId> <artifactId>jcommander</artifactId>

View File

@ -52,6 +52,7 @@ public class CSVSparkTransformServerNoJsonTest {
public static void before() throws Exception { public static void before() throws Exception {
server = new CSVSparkTransformServer(); server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson()); FileUtils.write(fileSave, transformProcess.toJson());
// Only one time // Only one time
Unirest.setObjectMapper(new ObjectMapper() { Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
@ -73,6 +74,7 @@ public class CSVSparkTransformServerNoJsonTest {
} }
} }
}); });
server.runMain(new String[] {"-dp", "9050"}); server.runMain(new String[] {"-dp", "9050"});
} }

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform; package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;
@ -49,6 +50,7 @@ public class CSVSparkTransformServerTest {
server = new CSVSparkTransformServer(); server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson()); FileUtils.write(fileSave, transformProcess.toJson());
// Only one time // Only one time
Unirest.setObjectMapper(new ObjectMapper() { Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper(); new org.nd4j.shade.jackson.databind.ObjectMapper();
@ -69,6 +71,7 @@ public class CSVSparkTransformServerTest {
} }
} }
}); });
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
} }

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform; package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;

View File

@ -16,6 +16,7 @@
package org.datavec.spark.transform; package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode; import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;

View File

@ -19,10 +19,10 @@ package org.deeplearning4j.nearestneighbor.client;
import com.mashape.unirest.http.ObjectMapper; import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest; import com.mashape.unirest.request.HttpRequest;
import com.mashape.unirest.request.HttpRequestWithBody;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*; import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.base64.Nd4jBase64;
@ -51,6 +51,7 @@ public class NearestNeighborsClient {
static { static {
// Only one time // Only one time
Unirest.setObjectMapper(new ObjectMapper() { Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper(); new org.nd4j.shade.jackson.databind.ObjectMapper();
@ -89,7 +90,7 @@ public class NearestNeighborsClient {
NearestNeighborRequest request = new NearestNeighborRequest(); NearestNeighborRequest request = new NearestNeighborRequest();
request.setInputIndex(index); request.setInputIndex(index);
request.setK(k); request.setK(k);
HttpRequestWithBody req = Unirest.post(url + "/knn"); val req = Unirest.post(url + "/knn");
req.header("accept", "application/json") req.header("accept", "application/json")
.header("Content-Type", "application/json").body(request); .header("Content-Type", "application/json").body(request);
addAuthHeader(req); addAuthHeader(req);
@ -112,7 +113,7 @@ public class NearestNeighborsClient {
Base64NDArrayBody base64NDArrayBody = Base64NDArrayBody base64NDArrayBody =
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
HttpRequestWithBody req = Unirest.post(url + "/knnnew"); val req = Unirest.post(url + "/knnnew");
req.header("accept", "application/json") req.header("accept", "application/json")
.header("Content-Type", "application/json").body(base64NDArrayBody); .header("Content-Type", "application/json").body(base64NDArrayBody);
addAuthHeader(req); addAuthHeader(req);

View File

@ -19,7 +19,6 @@ package org.deeplearning4j.models.word2vec;
import com.google.gson.JsonArray; import com.google.gson.JsonArray;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import com.google.gson.JsonParser; import com.google.gson.JsonParser;
import jdk.nashorn.internal.objects.annotations.Property;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.apache.commons.compress.compressors.gzip.GzipUtils; import org.apache.commons.compress.compressors.gzip.GzipUtils;

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.nn.adapters; package org.deeplearning4j.nn.adapters;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter; import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;

View File

@ -18,7 +18,7 @@ package org.deeplearning4j.nn.adapters;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.api.OutputAdapter; import org.nd4j.adapters.OutputAdapter;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.nn.api; package org.deeplearning4j.nn.api;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
/** /**

View File

@ -24,6 +24,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JException;

View File

@ -25,6 +25,7 @@ import lombok.val;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;; import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.RegressionEvaluation; import org.deeplearning4j.eval.RegressionEvaluation;

View File

@ -0,0 +1,110 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>jar</packaging>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-json-server</artifactId>
<version>1.0.0-SNAPSHOT</version>
<name>deeplearning4j-json-server</name>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-json-client</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-json-server</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-parallel-wrapper</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,205 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.remote.serving.SameDiffServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
/**
*
* @author astoyakin
*/
@Slf4j
@NoArgsConstructor
public class DL4jServlet<I,O> extends SameDiffServlet<I,O> {
protected ParallelInference parallelInference;
protected Model model;
protected boolean parallelEnabled = true;
public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter,
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
super(inferenceAdapter, serializer, deserializer);
this.parallelInference = parallelInference;
this.model = null;
this.parallelEnabled = true;
}
public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter,
@NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer) {
super(inferenceAdapter, serializer, deserializer);
this.model = model;
this.parallelInference = null;
this.parallelEnabled = false;
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
String processorReturned = "";
String path = request.getPathInfo();
if (path.equals(SERVING_ENDPOINT)) {
val contentType = request.getContentType();
if (validateRequest(request,response)) {
val stream = request.getInputStream();
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
char[] charBuffer = new char[128];
int bytesRead = -1;
val buffer = new StringBuilder();
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
buffer.append(charBuffer, 0, bytesRead);
}
val requestString = buffer.toString();
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
O result = null;
if (parallelEnabled) {
// process result
result = inferenceAdapter.apply(parallelInference.output(mds.getFeatures(), mds.getFeaturesMaskArrays()));
}
else {
synchronized(this) {
if (model instanceof ComputationGraph)
result = inferenceAdapter.apply(((ComputationGraph)model).output(false, mds.getFeatures(), mds.getFeaturesMaskArrays()));
else if (model instanceof MultiLayerNetwork) {
Preconditions.checkArgument(mds.getFeatures().length > 1 || (mds.getFeaturesMaskArrays() != null && mds.getFeaturesMaskArrays().length > 1),
"Input data for MultilayerNetwork is invalid!");
result = inferenceAdapter.apply(((MultiLayerNetwork) model).output(mds.getFeatures()[0], false,
mds.getFeaturesMaskArrays() != null ? mds.getFeaturesMaskArrays()[0] : null, null));
}
}
}
processorReturned = serializer.serialize(result);
}
} else {
// we return error otherwise
sendError(request.getRequestURI(), response);
}
try {
val out = response.getWriter();
out.write(processorReturned);
} catch (IOException e) {
log.error(e.getMessage());
}
}
/**
* Creates servlet to serve models
*
* @param <I> type of Input class
* @param <O> type of Output class
*
* @author raver119@gmail.com
* @author astoyakin
*/
public static class Builder<I,O> {
private ParallelInference pi;
private Model model;
private InferenceAdapter<I, O> inferenceAdapter;
private JsonSerializer<O> serializer;
private JsonDeserializer<I> deserializer;
private int port;
private boolean parallelEnabled = true;
public Builder(@NonNull ParallelInference pi) {
this.pi = pi;
}
public Builder(@NonNull Model model) {
this.model = model;
}
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
this.inferenceAdapter = inferenceAdapter;
return this;
}
/**
* This method is required to specify serializer
*
* @param serializer
* @return
*/
public Builder<I,O> serializer(@NonNull JsonSerializer<O> serializer) {
this.serializer = serializer;
return this;
}
/**
* This method allows to specify deserializer
*
* @param deserializer
* @return
*/
public Builder<I,O> deserializer(@NonNull JsonDeserializer<I> deserializer) {
this.deserializer = deserializer;
return this;
}
/**
* This method allows to specify port
*
* @param port
* @return
*/
public Builder<I,O> port(int port) {
this.port = port;
return this;
}
/**
* This method activates parallel inference
*
* @param parallelEnabled
* @return
*/
public Builder<I,O> parallelEnabled(boolean parallelEnabled) {
this.parallelEnabled = parallelEnabled;
return this;
}
public DL4jServlet<I,O> build() {
return parallelEnabled ? new DL4jServlet<I, O>(pi, inferenceAdapter, serializer, deserializer) :
new DL4jServlet<I, O>(model, inferenceAdapter, serializer, deserializer);
}
}
}

View File

@ -0,0 +1,392 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.adapters.InputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.remote.SameDiffJsonModelServer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.util.List;
/**
* This class provides JSON-based model serving ability for Deeplearning4j/SameDiff models
*
* Server url will be http://0.0.0.0:{port}>/v1/serving
* Server only accepts POST requests
*
* @param <I> type of the input class, i.e. String
* @param <O> type of the output class, i.e. Sentiment
*
* @author raver119@gmail.com
* @author astoyakin
*/
public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
// all serving goes through ParallelInference
protected ParallelInference parallelInference;
protected ModelAdapter<O> modelAdapter;
// actual models
protected ComputationGraph cgModel;
protected MultiLayerNetwork mlnModel;
// service stuff
protected InferenceMode inferenceMode;
protected int numWorkers;
protected boolean enabledParallel = true;
protected JsonModelServer(@NonNull SameDiff sdModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
super(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
}
protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
super(inferenceAdapter, serializer, deserializer, port);
this.cgModel = cgModel;
this.inferenceMode = inferenceMode;
this.numWorkers = numWorkers;
}
protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
super(inferenceAdapter, serializer, deserializer, port);
this.mlnModel = mlnModel;
this.inferenceMode = inferenceMode;
this.numWorkers = numWorkers;
}
protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, int port) {
super(inferenceAdapter, serializer, deserializer, port);
this.parallelInference = pi;
}
/**
* This method stops server
*
* @throws Exception
*/
@Override
public void stop() throws Exception {
if (parallelInference != null)
parallelInference.shutdown();
super.stop();
}
/**
* This method starts server
* @throws Exception
*/
@Override
public void start() throws Exception {
// if we're just serving sdModel - we'll just call super. no dl4j functionality required in this case
if (sdModel != null) {
super.start();
return;
}
Preconditions.checkArgument(cgModel != null || mlnModel != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined");
val model = cgModel != null ? (Model) cgModel : (Model) mlnModel;
// PI construction is optional, since we can have it defined
if (enabledParallel) {
if (parallelInference == null) {
Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");
parallelInference = new ParallelInference.Builder(model)
.inferenceMode(inferenceMode)
.workers(numWorkers)
.loadBalanceMode(LoadBalanceMode.FIFO)
.batchLimit(16)
.queueLimit(128)
.build();
}
servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
.parallelEnabled(true)
.serializer(serializer)
.deserializer(deserializer)
.inferenceAdapter(inferenceAdapter)
.build();
}
else {
servingServlet = new DL4jServlet.Builder<I, O>(model)
.parallelEnabled(false)
.serializer(serializer)
.deserializer(deserializer)
.inferenceAdapter(inferenceAdapter)
.build();
}
start(port, servingServlet);
}
/**
* Creates servlet to serve different types of models
*
* @param <I> type of Input class
* @param <O> type of Output class
*
* @author raver119@gmail.com
* @author astoyakin
*/
public static class Builder<I,O> {
private SameDiff sdModel;
private ComputationGraph cgModel;
private MultiLayerNetwork mlnModel;
private ParallelInference pi;
private String[] orderedInputNodes;
private String[] orderedOutputNodes;
private InferenceAdapter<I, O> inferenceAdapter;
private JsonSerializer<O> serializer;
private JsonDeserializer<I> deserializer;
private InputAdapter<I> inputAdapter;
private OutputAdapter<O> outputAdapter;
private int port;
private boolean parallelMode = true;
// these fields actually require defaults
private InferenceMode inferenceMode = InferenceMode.BATCHED;
private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();
public Builder(@NonNull SameDiff sdModel) {
this.sdModel = sdModel;
}
public Builder(@NonNull MultiLayerNetwork mlnModel) {
this.mlnModel = mlnModel;
}
public Builder(@NonNull ComputationGraph cgModel) {
this.cgModel = cgModel;
}
public Builder(@NonNull ParallelInference pi) {
this.pi = pi;
}
/**
* This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type
* @param inferenceAdapter
* @return
*/
public Builder<I,O> inferenceAdapter(@NonNull InferenceAdapter<I,O> inferenceAdapter) {
this.inferenceAdapter = inferenceAdapter;
return this;
}
/**
* This method allows you to specify InputAdapter to be used for inference
*
* PLEASE NOTE: This method is optional, and will require OutputAdapter<O> defined
* @param inputAdapter
* @return
*/
public Builder<I,O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
this.inputAdapter = inputAdapter;
return this;
}
/**
* This method allows you to specify OutputtAdapter to be used for inference
*
* PLEASE NOTE: This method is optional, and will require InputAdapter<I> defined
* @param outputAdapter
* @return
*/
public Builder<I,O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
this.outputAdapter = outputAdapter;
return this;
}
/**
* This method allows you to specify serializer
*
* @param serializer
* @return
*/
public Builder<I,O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
this.serializer = serializer;
return this;
}
/**
* This method allows you to specify deserializer
*
* @param deserializer
* @return
*/
public Builder<I,O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
this.deserializer = deserializer;
return this;
}
/**
* This method allows you to specify inference mode for parallel mode. See {@link InferenceMode} for more details
*
* @param inferenceMode
* @return
*/
public Builder<I,O> inferenceMode(@NonNull InferenceMode inferenceMode) {
this.inferenceMode = inferenceMode;
return this;
}
/**
* This method allows you to specify number of worker threads for ParallelInference
*
* @param numWorkers
* @return
*/
public Builder<I,O> numWorkers(int numWorkers) {
this.numWorkers = numWorkers;
return this;
}
/**
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
*
* PLEASE NOTE: this argument only used for SameDiff models
* @param args
* @return
*/
public Builder<I,O> orderedInputNodes(String... args) {
orderedInputNodes = args;
return this;
}
/**
* This method allows you to specify the order in which the inputs should be mapped to the model placeholder arrays. This is only required for {@link SameDiff} models, not {@link MultiLayerNetwork} or {@link ComputationGraph} models
*
* PLEASE NOTE: this argument only used for SameDiff models
* @param args
* @return
*/
public Builder<I,O> orderedInputNodes(@NonNull List<String> args) {
orderedInputNodes = args.toArray(new String[args.size()]);
return this;
}
/**
* This method allows you to specify output nodes
*
* PLEASE NOTE: this argument only used for SameDiff models
* @param args
* @return
*/
public Builder<I,O> orderedOutputNodes(String... args) {
Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
orderedOutputNodes = args;
return this;
}
/**
* This method allows you to specify output nodes
*
* PLEASE NOTE: this argument only used for SameDiff models
* @param args
* @return
*/
public Builder<I,O> orderedOutputNodes(@NonNull List<String> args) {
Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
orderedOutputNodes = args.toArray(new String[args.size()]);
return this;
}
/**
* This method allows you to specify http port
*
* PLEASE NOTE: port must be free and be in range regular TCP/IP ports range
* @param port
* @return
*/
public Builder<I,O> port(int port) {
this.port = port;
return this;
}
/**
* This method switches on ParallelInference usage
* @param - true - to use ParallelInference, false - to use ComputationGraph or
* MultiLayerNetwork directly
*
* PLEASE NOTE: this doesn't apply to SameDiff models
*
* @throws Exception
*/
public Builder<I,O> parallelMode(boolean enable) {
this.parallelMode = enable;
return this;
}
public JsonModelServer<I,O> build() {
if (inferenceAdapter == null) {
if (inputAdapter != null && outputAdapter != null) {
inferenceAdapter = new InferenceAdapter<I, O>() {
@Override
public MultiDataSet apply(I input) {
return inputAdapter.apply(input);
}
@Override
public O apply(INDArray... outputs) {
return outputAdapter.apply(outputs);
}
};
} else
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
}
if (sdModel != null) {
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "For SameDiff model serving OutputNodes should be defined");
return new JsonModelServer<I, O>(sdModel, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
} else if (cgModel != null)
return new JsonModelServer<I,O>(cgModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
else if (mlnModel != null)
return new JsonModelServer<I,O>(mlnModel, inferenceAdapter, serializer, deserializer, port, inferenceMode, numWorkers);
else if (pi != null)
return new JsonModelServer<I,O>(pi, inferenceAdapter, serializer, deserializer, port);
else
throw new IllegalStateException("No models were defined for JsonModelServer");
}
}
}

View File

@ -0,0 +1,749 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.remote.helpers.House;
import org.deeplearning4j.remote.helpers.HouseToPredictedPriceAdapter;
import org.deeplearning4j.remote.helpers.PredictedPrice;
import org.junit.After;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.remote.clients.JsonRemoteInference;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.deeplearning4j.parallelism.inference.InferenceMode.INPLACE;
import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL;
import static org.junit.Assert.*;
@Slf4j
public class JsonModelServerTest {
private static final MultiLayerNetwork model;
private final int PORT = 18080;
static {
val conf = new NeuralNetConfiguration.Builder()
.seed(119)
.updater(new Adam(0.119f))
.weightInit(WeightInit.XAVIER)
.list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(10).build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nIn(10).nOut(1).build())
.build();
model = new MultiLayerNetwork(conf);
model.init();
}
@After
public void pause() throws Exception {
// TODO: the same port was used in previous test and not accessible immediately. Might be better solution.
TimeUnit.SECONDS.sleep(2);
}
@Test
public void testStartStopParallel() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(SEQUENTIAL)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.port(PORT+1)
.build();
try {
serverDL.start();
serverSD.start();
val clientDL = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
PredictedPrice price = clientDL.predict(house);
long timeStart = System.currentTimeMillis();
price = clientDL.predict(house);
long timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
val clientSD = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
.build();
PredictedPrice price2 = clientSD.predict(house);
timeStart = System.currentTimeMillis();
price = clientSD.predict(house);
timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) 3.0, price.getPrice(), 1e-5);
}
finally {
serverSD.stop();
serverDL.stop();
}
}
@Test
public void testStartStopSequential() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val serverDL = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(SEQUENTIAL)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
val serverSD = new JsonModelServer.Builder<House, PredictedPrice>(sd)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.port(PORT+1)
.build();
serverDL.start();
serverDL.stop();
serverSD.start();
serverSD.stop();
}
@Test
public void basicServingTestForSD() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 1,4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new JsonModelServer.Builder<House, PredictedPrice>(sd)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
val timeStart = System.currentTimeMillis();
price = client.predict(house);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
}
finally {
server.stop();
}
}
@Test
public void basicServingTestForDLSynchronized() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(INPLACE)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house1 = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
House house2 = House.builder().area(50).bathrooms(1).bedrooms(2).district(district).build();
House house3 = House.builder().area(80).bathrooms(1).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house1);
val timeStart = System.currentTimeMillis();
PredictedPrice price1 = client.predict(house1);
PredictedPrice price2 = client.predict(house2);
PredictedPrice price3 = client.predict(house3);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
} finally {
server.stop();
}
}
@Test
public void basicServingTestForDL() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.numWorkers(1)
.inferenceMode(SEQUENTIAL)
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.parallelMode(false)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
val timeStart = System.currentTimeMillis();
price = client.predict(house);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) 0.421444, price.getPrice(), 1e-5);
} finally {
server.stop();
}
}
@Test
public void testDeserialization_1() {
String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
val deserializer = new House.HouseDeserializer();
val result = deserializer.deserialize(request);
assertEquals(2, result.getDistrict());
assertEquals(100, result.getArea());
assertEquals(2, result.getBathrooms());
assertEquals(3, result.getBedrooms());
}
@Test
public void testDeserialization_2() {
String request = "{\"price\":1}";
val deserializer = new PredictedPrice.PredictedPriceDeserializer();
val result = deserializer.deserialize(request);
assertEquals(1.0, result.getPrice(), 1e-4);
}
@Test(expected = NullPointerException.class)
public void negativeServingTest_1() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(null)
.port(18080)
.build();
}
@Test //(expected = NullPointerException.class)
public void negativeServingTest_2() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.port(PORT)
.build();
}
@Test(expected = IOException.class)
public void negativeServingTest_3() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
@Override
public PredictedPrice deserialize(String json) {
return null;
}
})
.endpointAddress("http://localhost:18080/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
} finally {
server.stop();
}
}
@Test
public void asyncServingTest() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.inferenceMode(SEQUENTIAL)
.numWorkers(1)
.port(PORT)
.build();
try {
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
assertEquals((float) 0.421444, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
}
finally {
server.stop();
}
}
@Test
public void negativeAsyncTest() throws Exception {
val server = new JsonModelServer.Builder<House, PredictedPrice>(model)
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.inferenceMode(InferenceMode.BATCHED)
.numWorkers(1)
.port(PORT)
.build();
try {
server.start();
// Fake deserializer to test failure
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
@Override
public PredictedPrice deserialize(String json) {
return null;
}
})
.endpointAddress("http://localhost:" + PORT + "/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
try {
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
} catch (ExecutionException e) {
assertTrue(e.getMessage().contains("Deserialization failed"));
}
} finally {
server.stop();
}
}
@Test
public void testSameDiffMnist() throws Exception {
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
val server = new JsonModelServer.Builder<float[], Integer>(sd)
.outputSerializer( new IntSerde())
.inputDeserializer(new FloatSerde())
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
@Override
public MultiDataSet apply(float[] input) {
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.orderedInputNodes("in")
.orderedOutputNodes("softmax")
.port(PORT+1)
.build();
val client = JsonRemoteInference.<float[], Integer>builder()
.endpointAddress("http://localhost:" + (PORT+1) + "/v1/serving")
.outputDeserializer(new IntSerde())
.inputSerializer( new FloatSerde())
.build();
try{
server.start();
for( int i=0; i<10; i++ ){
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28*28);
INDArray exp = sd.output(Collections.singletonMap("in", f), "softmax").get("softmax");
float[] fArr = f.toFloatVector();
int out = client.predict(fArr);
assertEquals(exp.argMax().getInt(0), out);
}
} finally {
server.stop();
}
}
@Test
public void testMlnMnist() throws Exception {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(10).build())
.layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
val server = new JsonModelServer.Builder<float[], Integer>(net)
.outputSerializer( new IntSerde())
.inputDeserializer(new FloatSerde())
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
@Override
public MultiDataSet apply(float[] input) {
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.orderedInputNodes("in")
.orderedOutputNodes("softmax")
.port(PORT + 1)
.inferenceMode(SEQUENTIAL)
.numWorkers(2)
.build();
val client = JsonRemoteInference.<float[], Integer>builder()
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
.outputDeserializer(new IntSerde())
.inputSerializer( new FloatSerde())
.build();
try {
server.start();
for (int i = 0; i < 10; i++) {
INDArray f = Nd4j.rand(DataType.FLOAT, 1, 28 * 28);
INDArray exp = net.output(f);
float[] fArr = f.toFloatVector();
int out = client.predict(fArr);
assertEquals(exp.argMax().getInt(0), out);
}
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
@Test
public void testCompGraph() throws Exception {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.graphBuilder()
.addInputs("input1", "input2")
.addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
.addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
.addVertex("merge", new MergeVertex(), "L1", "L2")
.addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge")
.setOutputs("out")
.build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
val server = new JsonModelServer.Builder<float[], Integer>(net)
.outputSerializer( new IntSerde())
.inputDeserializer(new FloatSerde())
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
@Override
public MultiDataSet apply(float[] input) {
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.orderedInputNodes("in")
.orderedOutputNodes("softmax")
.port(PORT + 1)
.inferenceMode(SEQUENTIAL)
.numWorkers(2)
.parallelMode(false)
.build();
val client = JsonRemoteInference.<float[], Integer>builder()
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
.outputDeserializer(new IntSerde())
.inputSerializer( new FloatSerde())
.build();
try {
server.start();
//client.predict(new float[]{0.0f, 1.0f, 2.0f});
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
@Test
public void testCompGraph_1() throws Exception {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(0.01))
.graphBuilder()
.addInputs("input")
.addLayer("L1", new DenseLayer.Builder().nIn(8).nOut(4).build(), "input")
.addLayer("out1", new OutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(4).nOut(3).build(), "L1")
.addLayer("out2", new OutputLayer.Builder()
.lossFunction(LossFunctions.LossFunction.MSE)
.nIn(4).nOut(2).build(), "L1")
.setOutputs("out1","out2")
.build();
final ComputationGraph net = new ComputationGraph(conf);
net.init();
val server = new JsonModelServer.Builder<float[], Integer>(net)
.outputSerializer( new IntSerde())
.inputDeserializer(new FloatSerde())
.inferenceAdapter(new InferenceAdapter<float[], Integer>() {
@Override
public MultiDataSet apply(float[] input) {
return new MultiDataSet(Nd4j.create(input, 1, input.length), null);
}
@Override
public Integer apply(INDArray... nnOutput) {
return nnOutput[0].argMax().getInt(0);
}
})
.orderedInputNodes("input")
.orderedOutputNodes("out")
.port(PORT + 1)
.inferenceMode(SEQUENTIAL)
.numWorkers(2)
.parallelMode(false)
.build();
val client = JsonRemoteInference.<float[], Integer>builder()
.endpointAddress("http://localhost:" + (PORT + 1) + "/v1/serving")
.outputDeserializer(new IntSerde())
.inputSerializer( new FloatSerde())
.build();
try {
server.start();
val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
assertNotNull(result);
} catch (Exception e){
e.printStackTrace();
throw e;
} finally {
server.stop();
}
}
private static class FloatSerde implements JsonSerializer<float[]>, JsonDeserializer<float[]>{
private final ObjectMapper om = new ObjectMapper();
@Override
public float[] deserialize(String json) {
try {
return om.readValue(json, FloatHolder.class).getFloats();
} catch (IOException e){
throw new RuntimeException(e);
}
}
@Override
public String serialize(float[] o) {
try{
return om.writeValueAsString(new FloatHolder(o));
} catch (IOException e){
throw new RuntimeException(e);
}
}
//Use float holder so Jackson does ser/de properly (no "{}" otherwise)
@AllArgsConstructor @NoArgsConstructor @Data
private static class FloatHolder {
private float[] floats;
}
}
private static class IntSerde implements JsonSerializer<Integer>, JsonDeserializer<Integer> {
private final ObjectMapper om = new ObjectMapper();
@Override
public Integer deserialize(String json) {
try {
return om.readValue(json, Integer.class);
} catch (IOException e){
throw new RuntimeException(e);
}
}
@Override
public String serialize(Integer o) {
try{
return om.writeValueAsString(o);
} catch (IOException e){
throw new RuntimeException(e);
}
}
}
}

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote;
import lombok.val;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.HttpClientBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class ServletTest {
private JsonModelServer server;
@Before
public void setUp() throws Exception {
val sd = SameDiff.create();
server = new JsonModelServer.Builder<String,String>(sd)
.port(8080)
.inferenceAdapter(new InferenceAdapter<String, String>() {
@Override
public MultiDataSet apply(String input) {
return null;
}
@Override
public String apply(INDArray... nnOutput) {
return null;
}
})
.outputSerializer(new JsonSerializer<String>() {
@Override
public String serialize(String o) {
return "";
}
})
.inputDeserializer(new JsonDeserializer<String>() {
@Override
public String deserialize(String json) {
return "";
}
})
.orderedInputNodes("input")
.orderedOutputNodes("output")
.build();
server.start();
//server.join();
}
@After
public void tearDown() throws Exception {
server.stop();
}
@Test
public void getEndpoints() throws IOException {
val request = new HttpGet( "http://localhost:8080/v1" );
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(200, response.getStatusLine().getStatusCode());
}
@Test
public void testContentTypeGet() throws IOException {
val request = new HttpGet( "http://localhost:8080/v1" );
request.setHeader("Content-type", "text/plain");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(415, response.getStatusLine().getStatusCode());
}
@Test
public void testContentTypePost() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving");
request.setHeader("Content-type", "text/plain");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(415, response.getStatusLine().getStatusCode());
}
@Test
public void postForServing() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving");
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(500, response.getStatusLine().getStatusCode());
}
@Test
public void testNotFoundPost() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving/some");
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(404, response.getStatusLine().getStatusCode());
}
@Test
public void testNotFoundGet() throws Exception {
val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
requestGet.setHeader("Content-type", "application/json");
val responseGet = HttpClientBuilder.create().build().execute( requestGet );
assertEquals(404, responseGet.getStatusLine().getStatusCode());
}
}

View File

@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import com.google.gson.Gson;
import lombok.*;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class House {
private int district;
private int bedrooms;
private int bathrooms;
private int area;
public static class HouseSerializer implements JsonSerializer<House> {
@Override
public String serialize(@NonNull House o) {
return new Gson().toJson(o);
}
}
public static class HouseDeserializer implements JsonDeserializer<House> {
@Override
public House deserialize(@NonNull String json) {
return new Gson().fromJson(json, House.class);
}
}
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
@Slf4j
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
@Override
public MultiDataSet apply(@NonNull House input) {
// we just create vector array with shape[4] and assign it's value to the district value
return new MultiDataSet(Nd4j.create(DataType.FLOAT, 1, 4).assign(input.getDistrict()), null);
}
@Override
public PredictedPrice apply(INDArray... nnOutput) {
return new PredictedPrice(nnOutput[0].getFloat(0));
}
}

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.remote.helpers;
import com.google.gson.Gson;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PredictedPrice {
private float price;
public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
@Override
public String serialize(@NonNull PredictedPrice o) {
return new Gson().toJson(o);
}
}
public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
@Override
public PredictedPrice deserialize(@NonNull String json) {
return new Gson().fromJson(json, PredictedPrice.class);
}
}
}

View File

@ -0,0 +1,48 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.eclipse.jetty" level="WARN" />
<logger name="org.apache.catalina.core" level="WARN" />
<logger name="org.springframework" level="WARN" />
<logger name="org.nd4j" level="DEBUG" />
<logger name="org.deeplearning4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -0,0 +1,30 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<modules>
<module>deeplearning4j-json-server</module>
</modules>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
<name>deeplearning4j-remote</name>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project>

View File

@ -21,7 +21,6 @@ import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter; import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;

View File

@ -22,7 +22,6 @@ import lombok.val;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter; import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;

View File

@ -144,6 +144,7 @@
<module>dl4j-perf</module> <module>dl4j-perf</module>
<module>dl4j-integration-tests</module> <module>dl4j-integration-tests</module>
<module>deeplearning4j-common</module> <module>deeplearning4j-common</module>
<module>deeplearning4j-remote</module>
</modules> </modules>
<dependencyManagement> <dependencyManagement>

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.adapters;
/**
* This interface describes methods needed to convert custom JVM objects to INDArrays, suitable for feeding neural networks
*
* @param <I> type of the Input for the model. I.e. String for raw text
* @param <O> type of the Output for the model, I.e. Sentiment, for Text->Sentiment extraction
*
* @author raver119@gmail.com
*/
public interface InferenceAdapter<I, O> extends InputAdapter<I>, OutputAdapter<O> {
}

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.adapters;
import org.nd4j.linalg.dataset.MultiDataSet;
/**
* This interface describes method for transformation from object of type I to MultiDataSet.
*
*/
public interface InputAdapter<I> {
/**
* This method converts input object to MultiDataSet
* @param input
* @return
*/
MultiDataSet apply(I input);
}

View File

@ -14,15 +14,15 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.nn.api; package org.nd4j.adapters;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable; import java.io.Serializable;
/** /**
* This interface describes entity used to conver neural network output to specified class. * This interface describes entity used to convert neural network output to specified class.
* I.e. INDArray -> int[] on the fly. * I.e. INDArray -> int[] or INDArray -> Sentiment on the fly
* *
* PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference. * PLEASE NOTE: Implementation will be used in workspace environment to avoid additional allocations during inference.
* This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly. * This means you shouldn't store or return the INDArrays passed to OutputAdapter.apply(INDArray...) directly.

View File

@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public native void printBuffer(); public native void printBuffer();
public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/);
/**
* print element by element consequently in a way they (elements) are stored in physical memory
*/
public native void printLinearBuffer();
/** /**
* prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status
*/ */
@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder);
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder);
@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432;
* Returns the prod of the data * Returns the prod of the data
* up to the given length * up to the given length
*/ */
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length);
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length);
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length);
@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432;
* @param originalTadNum the tad number for the reduced version of the problem * @param originalTadNum the tad number for the reduced version of the problem
*/ */
/**
* Returns the prod of the data
* up to the given length
*/
/** /**
* Returns the prod of the data * Returns the prod of the data

View File

@ -4036,6 +4036,11 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public native void printBuffer(); public native void printBuffer();
public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/); public native void printBuffer(@Cast("char*") BytePointer msg/*=nullptr*/, @Cast("Nd4jLong") long limit/*=-1*/, @Cast("const bool") boolean sync/*=true*/);
/**
* print element by element consequently in a way they (elements) are stored in physical memory
*/
public native void printLinearBuffer();
/** /**
* prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status * prints _buffer (if host = true) or _bufferD (if host = false) as it is, that is in current state without checking buffer status
*/ */
@ -7047,9 +7052,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength); @Namespace("shape") public static native int tadIndexForLinear(int linearIndex, int tadLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") LongBuffer shapeInfo, IntBuffer dimension, int dimensionLength);
@Namespace("shape") public static native int tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength); @Namespace("shape") public static native @Cast("Nd4jLong") long tadLength(@Cast("Nd4jLong*") long[] shapeInfo, int[] dimension, int dimensionLength);
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongPointer oldShape, int newRank, @Cast("Nd4jLong*") LongPointer newShape, @Cast("bool") boolean isFOrder);
@Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder); @Namespace("shape") public static native @Cast("bool") boolean canReshape(int oldRank, @Cast("Nd4jLong*") LongBuffer oldShape, int newRank, @Cast("Nd4jLong*") LongBuffer newShape, @Cast("bool") boolean isFOrder);
@ -7894,10 +7899,6 @@ public static final int PREALLOC_SIZE = 33554432;
* Returns the prod of the data * Returns the prod of the data
* up to the given length * up to the given length
*/ */
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongPointer data, int length);
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") LongBuffer data, int length);
@Namespace("shape") public static native int prod(@Cast("Nd4jLong*") long[] data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongPointer data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") LongBuffer data, int length);
@Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length); @Namespace("shape") public static native @Cast("Nd4jLong") long prodLong(@Cast("const Nd4jLong*") long[] data, int length);
@ -8745,10 +8746,6 @@ public static final int PREALLOC_SIZE = 33554432;
* @param originalTadNum the tad number for the reduced version of the problem * @param originalTadNum the tad number for the reduced version of the problem
*/ */
/**
* Returns the prod of the data
* up to the given length
*/
/** /**
* Returns the prod of the data * Returns the prod of the data

View File

@ -50,6 +50,8 @@
<artifactId>httpmime</artifactId> <artifactId>httpmime</artifactId>
<version>${httpmime.version}</version> <version>${httpmime.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mashape.unirest</groupId> <groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId> <artifactId>unirest-java</artifactId>

View File

@ -54,11 +54,13 @@
<artifactId>httpmime</artifactId> <artifactId>httpmime</artifactId>
<version>${httpmime.version}</version> <version>${httpmime.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.mashape.unirest</groupId> <groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId> <artifactId>unirest-java</artifactId>
<version>${unirest.version}</version> <version>${unirest.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-jackson</artifactId> <artifactId>nd4j-jackson</artifactId>

View File

@ -21,13 +21,15 @@ import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException; import com.beust.jcommander.ParameterException;
import com.beust.jcommander.Parameters; import com.beust.jcommander.Parameters;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import com.mashape.unirest.http.HttpResponse;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import com.mashape.unirest.http.Unirest; import com.mashape.unirest.http.Unirest;
import io.aeron.Aeron; import io.aeron.Aeron;
import io.aeron.driver.MediaDriver; import io.aeron.driver.MediaDriver;
import io.aeron.driver.ThreadingMode; import io.aeron.driver.ThreadingMode;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val;
import org.agrona.CloseHelper; import org.agrona.CloseHelper;
import org.agrona.concurrent.BusySpinIdleStrategy; import org.agrona.concurrent.BusySpinIdleStrategy;
import org.json.JSONObject; import org.json.JSONObject;
@ -49,7 +51,6 @@ import org.nd4j.parameterserver.updater.SoftSyncParameterUpdater;
import org.nd4j.parameterserver.updater.SynchronousParameterUpdater; import org.nd4j.parameterserver.updater.SynchronousParameterUpdater;
import org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage; import org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage;
import org.nd4j.parameterserver.util.CheckSocket; import org.nd4j.parameterserver.util.CheckSocket;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -342,7 +343,7 @@ public class ParameterServerSubscriber implements AutoCloseable {
JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState)); JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort,
streamId); streamId);
HttpResponse<String> entity = Unirest.post(url).header("Content-Type", "application/json") val entity = Unirest.post(url).header("Content-Type", "application/json")
.body(jsonObject).asString(); .body(jsonObject).asString();
} catch (Exception e) { } catch (Exception e) {
failCount.incrementAndGet(); failCount.incrementAndGet();

View File

View File

@ -19,13 +19,13 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>nd4j-serde</artifactId> <artifactId>nd4j-remote</artifactId>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<artifactId>nd4j-grpc</artifactId> <artifactId>nd4j-grpc-client</artifactId>
<name>nd4j-grpc</name> <name>nd4j-grpc</name>
<!-- FIXME change it to the project's website --> <!-- FIXME change it to the project's website -->

View File

@ -0,0 +1,70 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>jar</packaging>
<parent>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>nd4j-json-client</artifactId>
<name>nd4j-json-client</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>jackson</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>testresources</id>
<!-- Put nd4j-native in profile so that CUDA-only builds succeed -->
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<classifier>${javacpp.platform}</classifier>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,153 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients;
import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.json.JSONObject;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* This class provides remote inference functionality via JSON-powered REST APIs.
*
* Basically we assume that there's remote JSON server available (on bare metal or in k8s/swarm/whatever cluster), and with proper serializers/deserializers provided we can issue REST requests and get back responses.
* So, in this way application logic can be separated from DL logic.
*
* You just need to provide serializer/deserializer and address of the REST server, i.e. "http://model:8080/v1/serving"
*
* @param <I> type of the input class, i.e. String
* @param <O> type of the output class, i.e. Sentiment
*
* @author raver119@gmail.com
*/
@Slf4j
public class JsonRemoteInference<I, O> {
private String endpointAddress;
private JsonSerializer<I> serializer;
private JsonDeserializer<O> deserializer;
@Builder
public JsonRemoteInference(@NonNull String endpointAddress, @NonNull JsonSerializer<I> inputSerializer, @NonNull JsonDeserializer<O> outputDeserializer) {
this.endpointAddress = endpointAddress;
this.serializer = inputSerializer;
this.deserializer = outputDeserializer;
}
private O processResponse(HttpResponse<String> response) throws IOException {
if (response.getStatus() != 200)
throw new IOException("Inference request returned bad error code: " + response.getStatus());
O result = deserializer.deserialize(response.getBody());
if (result == null) {
throw new IOException("Deserialization failed!");
}
return result;
}
/**
* This method does remote inference in a blocking way
*
* @param input
* @return
* @throws IOException
*/
public O predict(I input) throws IOException {
try {
val stringResult = Unirest.post(endpointAddress)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(new JSONObject(serializer.serialize(input))).asString();
return processResponse(stringResult);
} catch (UnirestException e) {
throw new IOException(e);
}
}
/**
* This method does remote inference in asynchronous way, returning Future instead
* @param input
* @return
*/
public Future<O> predictAsync(I input) {
val stringResult = Unirest.post(endpointAddress)
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.body(new JSONObject(serializer.serialize(input))).asStringAsync();
return new InferenceFuture(stringResult);
}
/**
* This class holds a Future of the object returned by remote inference server
*/
private class InferenceFuture implements Future<O> {
private Future<HttpResponse<String>> unirestFuture;
private InferenceFuture(@NonNull Future<HttpResponse<String>> future) {
this.unirestFuture = future;
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return unirestFuture.cancel(mayInterruptIfRunning);
}
@Override
public boolean isCancelled() {
return unirestFuture.isCancelled();
}
@Override
public boolean isDone() {
return unirestFuture.isDone();
}
@Override
public O get() throws InterruptedException, ExecutionException {
val stringResult = unirestFuture.get();
try {
return processResponse(stringResult);
} catch (IOException e) {
throw new ExecutionException(e);
}
}
@Override
public O get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
val stringResult = unirestFuture.get(timeout, unit);
try {
return processResponse(stringResult);
} catch (IOException e) {
throw new ExecutionException(e);
}
}
}
}

View File

@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde;
/**
* This interface describes basic JSON deserializer interface used for JsonRemoteInference
* @param <T> type of the deserializable class
*
* @author raver119@gmail.com
*/
public interface JsonDeserializer<T> {
/**
* This method serializes given object into JSON-string
* @param json string containing JSON representation of the object
* @return
*/
T deserialize(String json);
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde;
/**
* This interface describes basic JSON serializer interface used for JsonRemoteInference
* @param <T> type of the serializable class
*
* @author raver119@gmail.com
*/
public interface JsonSerializer<T> {
/**
* This method serializes given object into JSON-string
*
* @param o object to be serialized
* @return
*/
String serialize(T o);
}

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
public abstract class AbstractSerDe<T> implements JsonDeserializer<T>, JsonSerializer<T> {
protected ObjectMapper objectMapper = new ObjectMapper();
protected String serializeClass(@NonNull T obj) {
try {
return objectMapper.writeValueAsString(obj);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
protected T deserializeClass(@NonNull String json, @NonNull Class<T> cls) {
try {
return objectMapper.readValue(json, cls);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
/**
* This class provides JSON ser/de for Java Boolean. Single value only.
*/
public class BooleanSerde extends AbstractSerDe<Boolean> {
@Override
public Boolean deserialize(@NonNull String json) {
return deserializeClass(json, Boolean.class);
}
@Override
public String serialize(@NonNull Boolean o) {
return serializeClass(o);
}
}

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.*;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
/**
* This class provides JSON ser/de for Java double[]
*/
public class DoubleArraySerde extends AbstractSerDe<double[]> {
@Override
public String serialize(@NonNull double[] data) {
return serializeClass(data);
}
@Override
public double[] deserialize(@NonNull String json) {
return deserializeClass(json, double[].class);
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
/**
* This class provides JSON ser/de for Java Double. Single value only.
*/
public class DoubleSerde extends AbstractSerDe<Double> {
@Override
public Double deserialize(@NonNull String json) {
return deserializeClass(json, Double.class);
}
@Override
public String serialize(@NonNull Double o) {
return serializeClass(o);
}
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.*;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
/**
* This class provides JSON ser/de for Java float[]
*/
public class FloatArraySerde extends AbstractSerDe<float[]> {
@Override
public String serialize(@NonNull float[] data) {
return serializeClass(data);
}
@Override
public float[] deserialize(@NonNull String json) {
return deserializeClass(json, float[].class);
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
/**
* This class provides JSON ser/de for Java Float. Single value only.
*/
public class FloatSerde extends AbstractSerDe<Float> {
@Override
public Float deserialize(@NonNull String json) {
return deserializeClass(json, Float.class);
}
@Override
public String serialize(@NonNull Float o) {
return serializeClass(o);
}
}

View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
/**
* This class provides JSON ser/de for Java Integer. Single value only.
*/
public class IntegerSerde extends AbstractSerDe<Integer> {
@Override
public Integer deserialize(@NonNull String json) {
return deserializeClass(json, Integer.class);
}
@Override
public String serialize(@NonNull Integer o) {
return serializeClass(o);
}
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.clients.serde.impl;
import lombok.NonNull;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
/**
* This class provides fake JSON serializer/deserializer functionality for String.
* It doesn't put any JSON-specific bits into actual string
*/
public class StringSerde extends AbstractSerDe<String> {
@Override
public String serialize(@NonNull String data) {
return serializeClass(data);
}
@Override
public String deserialize(@NonNull String json) {
return deserializeClass(json, String.class);
}
}

View File

@ -0,0 +1,35 @@
## SameDiff model serving
This modules provides JSON-based serving of SameDiff models
## Example
First of all we'll create server instance. Most probably you'll do it in application that will be running in container
```java
val server = SameDiffJsonModelServer.<String, Sentiment>builder()
.adapter(new StringToSentimentAdapter())
.model(mySameDiffModel)
.port(8080)
.serializer(new SentimentSerializer())
.deserializer(new StringDeserializer())
.build();
server.start();
server.join();
```
Now, presumably in some other container, we'll set up remote inference client:
```java
val client = JsonRemoteInference.<String, Sentiment>builder()
.endpointAddress("http://youraddress:8080/v1/serving")
.serializer(new StringSerializer())
.deserializer(new SentimentDeserializer())
.build();
Sentiment result = client.predict(myText);
```
On top of that, there's async call available, for cases when you need to chain multiple requests to one or multiple remote model servers.
```java
Future<Sentiment> result = client.predictAsync(myText);
```

View File

@ -0,0 +1,179 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>jar</packaging>
<parent>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>nd4j-json-server</artifactId>
<name>nd4j-json-server</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
<jersey.version>2.26</jersey.version>
</properties>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-json-client</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.core</groupId>
<artifactId>jersey-client</artifactId>
<version>${jersey.version}</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.core</groupId>
<artifactId>jersey-server</artifactId>
<version>${jersey.version}</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
<version>9.4.19.v20190610</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-servlet</artifactId>
<version>9.4.19.v20190610</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.inject</groupId>
<artifactId>jersey-hk2</artifactId>
<version>${jersey.version}</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.media</groupId>
<artifactId>jersey-media-json-processing</artifactId>
<version>${jersey.version}</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.containers</groupId>
<artifactId>jersey-container-servlet-core</artifactId>
<version>${jersey.version}</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>javax.xml.bind</groupId>
<artifactId>jaxb-api</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-impl</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-core</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>javax.activation</groupId>
<artifactId>activation</artifactId>
<version>1.1</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>nd4j-tests-cpu</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>nd4j-tests-cuda</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>testresources</id>
</profile>
</profiles>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>${maven.compiler.source}</source>
<target>${maven.compiler.target}</target>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,288 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.nd4j.adapters.InputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.remote.serving.ModelServingServlet;
import org.nd4j.remote.serving.SameDiffServlet;
import java.util.List;
/**
* This class provides JSON-powered model serving functionality for SameDiff graphs.
* Server url will be http://0.0.0.0:{port}>/v1/serving
* Server only accepts POST requests
*
* @param <I> type of the input class, i.e. String
* @param <O> type of the output class, i.e. Sentiment
*
* @author raver119@gmail.com
*/
@Slf4j
public class SameDiffJsonModelServer<I, O> {
protected SameDiff sdModel;
protected final JsonSerializer<O> serializer;
protected final JsonDeserializer<I> deserializer;
protected final InferenceAdapter<I, O> inferenceAdapter;
protected final int port;
// this servlet will be used to serve models
protected ModelServingServlet<I, O> servingServlet;
// HTTP server instance
protected Server server;
// for SameDiff only
protected String[] orderedInputNodes;
protected String[] orderedOutputNodes;
protected SameDiffJsonModelServer(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port) {
Preconditions.checkArgument(port > 0 && port < 65535, "TCP port must be in range of 0..65535");
this.inferenceAdapter = inferenceAdapter;
this.serializer = serializer;
this.deserializer = deserializer;
this.port = port;
}
//@Builder
public SameDiffJsonModelServer(SameDiff sdModel, @NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer, int port, String[] orderedInputNodes, @NonNull String[] orderedOutputNodes) {
this(inferenceAdapter, serializer, deserializer, port);
this.sdModel = sdModel;
this.orderedInputNodes = orderedInputNodes;
this.orderedOutputNodes = orderedOutputNodes;
// TODO: both lists of nodes should be validated, to make sure nodes specified here exist in actual model
if (orderedInputNodes != null) {
// input nodes list might be null. strange but ok
}
Preconditions.checkArgument(orderedOutputNodes != null && orderedOutputNodes.length > 0, "SameDiff serving requires at least 1 output node");
}
protected void start(int port, @NonNull ModelServingServlet<I, O> servlet) throws Exception {
val context = new ServletContextHandler(ServletContextHandler.SESSIONS);
context.setContextPath("/");
server = new Server(port);
server.setHandler(context);
val jerseyServlet = context.addServlet(org.glassfish.jersey.servlet.ServletContainer.class, "/*");
jerseyServlet.setInitOrder(0);
jerseyServlet.setServlet(servlet);
server.start();
}
public void start() throws Exception {
Preconditions.checkArgument(sdModel != null, "SameDiff model wasn't defined");
servingServlet = SameDiffServlet.<I, O>builder()
.sdModel(sdModel)
.serializer(serializer)
.deserializer(deserializer)
.inferenceAdapter(inferenceAdapter)
.orderedInputNodes(orderedInputNodes)
.orderedOutputNodes(orderedOutputNodes)
.build();
start(port, servingServlet);
}
public void join() throws InterruptedException {
Preconditions.checkArgument(server != null, "Model server wasn't started yet");
server.join();
}
public void stop() throws Exception {
//Preconditions.checkArgument(server != null, "Model server wasn't started yet");
server.stop();
}
public static class Builder<I,O> {
private SameDiff sameDiff;
private String[] orderedInputNodes;
private String[] orderedOutputNodes;
private InferenceAdapter<I, O> inferenceAdapter;
private JsonSerializer<O> serializer;
private JsonDeserializer<I> deserializer;
private int port;
private InputAdapter<I> inputAdapter;
private OutputAdapter<O> outputAdapter;
public Builder() {}
public Builder<I,O> sdModel(@NonNull SameDiff sameDiff) {
this.sameDiff = sameDiff;
return this;
}
/**
* This method defines InferenceAdapter implementation, which will be used to convert object of Input type to the set of INDArray(s), and for conversion of resulting INDArray(s) into object of Output type
* @param inferenceAdapter
* @return
*/
public Builder<I,O> inferenceAdapter(InferenceAdapter<I,O> inferenceAdapter) {
this.inferenceAdapter = inferenceAdapter;
return this;
}
/**
* This method allows you to specify InputAdapter to be used for inference
*
* PLEASE NOTE: This method is optional, and will require OutputAdapter<O> defined
* @param inputAdapter
* @return
*/
public Builder<I,O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
this.inputAdapter = inputAdapter;
return this;
}
/**
* This method allows you to specify OutputAdapter to be used for inference
*
* PLEASE NOTE: This method is optional, and will require InputAdapter<I> defined
* @param outputAdapter
* @return
*/
public Builder<I,O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
this.outputAdapter = outputAdapter;
return this;
}
/**
* This method defines JsonSerializer instance to be used to convert object of output type into JSON format, so it could be sent over the wire
*
* @param serializer
* @return
*/
public Builder<I,O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
this.serializer = serializer;
return this;
}
/**
* This method defines JsonDeserializer instance to be used to convert JSON passed through HTTP into actual object of input type, that will be fed into SameDiff model
*
* @param deserializer
* @return
*/
public Builder<I,O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
this.deserializer = deserializer;
return this;
}
/**
* This method defines the order of placeholders to be filled with INDArrays provided by Deserializer
*
* @param args
* @return
*/
public Builder<I,O> orderedInputNodes(String... args) {
orderedInputNodes = args;
return this;
}
/**
* This method defines the order of placeholders to be filled with INDArrays provided by Deserializer
*
* @param args
* @return
*/
public Builder<I,O> orderedInputNodes(@NonNull List<String> args) {
orderedInputNodes = args.toArray(new String[args.size()]);
return this;
}
/**
* This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input
* @param args
* @return
*/
public Builder<I,O> orderedOutputNodes(String... args) {
Preconditions.checkArgument(args != null && args.length > 0, "OutputNodes should contain at least 1 element");
orderedOutputNodes = args;
return this;
}
/**
* This method defines list of graph nodes to be extracted after feed-forward pass and used as OutputAdapter input
* @param args
* @return
*/
public Builder<I,O> orderedOutputNodes(@NonNull List<String> args) {
Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
orderedOutputNodes = args.toArray(new String[args.size()]);
return this;
}
/**
* This method allows to configure HTTP port used for serving
*
* PLEASE NOTE: port must be free and be in range regular TCP/IP ports range
* @param port
* @return
*/
public Builder<I,O> port(int port) {
this.port = port;
return this;
}
/**
* This method builds SameDiffJsonModelServer instance
* @return
*/
public SameDiffJsonModelServer<I, O> build() {
if (inferenceAdapter == null) {
if (inputAdapter != null && outputAdapter != null) {
inferenceAdapter = new InferenceAdapter<I, O>() {
@Override
public MultiDataSet apply(I input) {
return inputAdapter.apply(input);
}
@Override
public O apply(INDArray... outputs) {
return outputAdapter.apply(outputs);
}
};
} else
throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
}
return new SameDiffJsonModelServer<I,O>(sameDiff, inferenceAdapter, serializer, deserializer, port, orderedInputNodes, orderedOutputNodes);
}
}
}

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.serving;
import javax.servlet.Servlet;
/**
* This interface describes Servlet interface extension, suited for ND4J/DL4J model serving
* @param <I>
* @param <O>
*
* @author raver119@gmail.com
*/
public interface ModelServingServlet<I, O> extends Servlet {
//
}

View File

@ -0,0 +1,206 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.serving;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.HttpMethod;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.LinkedHashMap;
import static javax.ws.rs.core.MediaType.APPLICATION_JSON;
/**
* This servlet provides SameDiff model serving capabilities
*
* @param <I>
* @param <O>
*
* @author raver119@gmail.com
*/
@NoArgsConstructor
@AllArgsConstructor
@Slf4j
@Builder
public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
protected SameDiff sdModel;
protected JsonSerializer<O> serializer;
protected JsonDeserializer<I> deserializer;
protected InferenceAdapter<I, O> inferenceAdapter;
protected String[] orderedInputNodes;
protected String[] orderedOutputNodes;
protected final static String SERVING_ENDPOINT = "/v1/serving";
protected final static String LISTING_ENDPOINT = "/v1";
protected final static int PAYLOAD_SIZE_LIMIT = 10 * 1024; // TODO: should be customizable
protected SameDiffServlet(@NonNull InferenceAdapter<I, O> inferenceAdapter, @NonNull JsonSerializer<O> serializer, @NonNull JsonDeserializer<I> deserializer){
this.serializer = serializer;
this.deserializer = deserializer;
this.inferenceAdapter = inferenceAdapter;
}
@Override
public void init(ServletConfig servletConfig) throws ServletException {
//
}
@Override
public ServletConfig getServletConfig() {
return null;
}
@Override
public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException {
// we'll parse request here, and do model serving
val httpRequest = (HttpServletRequest) servletRequest;
val httpResponse = (HttpServletResponse) servletResponse;
if (httpRequest.getMethod().equals(HttpMethod.GET)) {
doGet(httpRequest, httpResponse);
}
else if (httpRequest.getMethod().equals(HttpMethod.POST)) {
doPost(httpRequest, httpResponse);
}
}
protected void sendError(String uri, HttpServletResponse response) throws IOException {
val msg = "Requested endpoint [" + uri + "] not found";
response.setStatus(404, msg);
response.sendError(404, msg);
}
protected void sendBadContentType(String actualContentType, HttpServletResponse response) throws IOException {
val msg = "Content type [" + actualContentType + "] not supported";
response.setStatus(415, msg);
response.sendError(415, msg);
}
protected boolean validateRequest(HttpServletRequest request, HttpServletResponse response)
throws IOException{
val contentType = request.getContentType();
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
sendBadContentType(contentType, response);
int contentLength = request.getContentLength();
if (contentLength > PAYLOAD_SIZE_LIMIT) {
response.sendError(500, "Payload size limit violated!");
}
return false;
}
return true;
}
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException {
val processor = new ServingProcessor();
String processorReturned = "";
String path = request.getPathInfo();
if (path.equals(LISTING_ENDPOINT)) {
val contentType = request.getContentType();
if (!StringUtils.equals(contentType, APPLICATION_JSON)) {
sendBadContentType(contentType, response);
}
processorReturned = processor.listEndpoints();
}
else {
sendError(request.getRequestURI(), response);
}
try {
val out = response.getWriter();
out.write(processorReturned);
} catch (IOException e) {
log.error(e.getMessage());
}
}
protected void doPost(HttpServletRequest request, HttpServletResponse response) throws IOException {
val processor = new ServingProcessor();
String processorReturned = "";
String path = request.getPathInfo();
if (path.equals(SERVING_ENDPOINT)) {
val contentType = request.getContentType();
/*Preconditions.checkArgument(StringUtils.equals(contentType, APPLICATION_JSON),
"Content type is " + contentType);*/
if (validateRequest(request,response)) {
val stream = request.getInputStream();
val bufferedReader = new BufferedReader(new InputStreamReader(stream));
char[] charBuffer = new char[128];
int bytesRead = -1;
val buffer = new StringBuilder();
while ((bytesRead = bufferedReader.read(charBuffer)) > 0) {
buffer.append(charBuffer, 0, bytesRead);
}
val requestString = buffer.toString();
val mds = inferenceAdapter.apply(deserializer.deserialize(requestString));
val map = new LinkedHashMap<String, INDArray>();
// optionally define placeholders with names provided in server constructor
if (orderedInputNodes != null && orderedInputNodes.length > 0) {
int cnt = 0;
for (val n : orderedInputNodes)
map.put(n, mds.getFeatures(cnt++));
}
val output = sdModel.exec(map, orderedOutputNodes);
val arrays = new INDArray[output.size()];
// now we need to get ordered output arrays, as specified in server constructor
int cnt = 0;
for (val n : orderedOutputNodes)
arrays[cnt++] = output.get(n);
// process result
val result = inferenceAdapter.apply(arrays);
processorReturned = serializer.serialize(result);
}
} else {
// we return error otherwise
sendError(request.getRequestURI(), response);
}
try {
val out = response.getWriter();
out.write(processorReturned);
} catch (IOException e) {
log.error(e.getMessage());
}
}
@Override
public String getServletInfo() {
return null;
}
@Override
public void destroy() {
//
}
}

View File

@ -0,0 +1,14 @@
package org.nd4j.remote.serving;
public class ServingProcessor {
public String listEndpoints() {
String retVal = "/v1/ \n/v1/serving/";
return retVal;
}
public String processModel(String body) {
String response = null; //"Not implemented";
return response;
}
}

View File

@ -0,0 +1,271 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.remote.clients.JsonRemoteInference;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.helpers.House;
import org.nd4j.remote.helpers.HouseToPredictedPriceAdapter;
import org.nd4j.remote.helpers.PredictedPrice;
import org.nd4j.remote.clients.serde.impl.FloatArraySerde;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import static org.junit.Assert.*;
@Slf4j
public class SameDiffJsonModelServerTest {
@Test
public void basicServingTest_1() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.sdModel(sd)
.port(18080)
.build();
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:18080/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
val timeStart = System.currentTimeMillis();
price = client.predict(house);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
assertNotNull(price);
assertEquals((float) district + 1.0f, price.getPrice(), 1e-5);
server.stop();
}
@Test
public void testDeserialization_1() {
String request = "{\"bedrooms\":3,\"area\":100,\"district\":2,\"bathrooms\":2}";
val deserializer = new House.HouseDeserializer();
val result = deserializer.deserialize(request);
assertEquals(2, result.getDistrict());
assertEquals(100, result.getArea());
assertEquals(2, result.getBathrooms());
assertEquals(3, result.getBedrooms());
}
@Test
public void testDeserialization_2() {
String request = "{\"price\":1}";
val deserializer = new PredictedPrice.PredictedPriceDeserializer();
val result = deserializer.deserialize(request);
assertEquals(1.0, result.getPrice(), 1e-4);
}
@Test
public void testDeserialization_3() {
float[] data = {0.0f, 0.1f, 0.2f};
val serialized = new FloatArraySerde().serialize(data);
val deserialized = new FloatArraySerde().deserialize(serialized);
assertArrayEquals(data, deserialized, 1e-5f);
}
@Test(expected = NullPointerException.class)
public void negativeServingTest_1() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(null)
.sdModel(sd)
.port(18080)
.build();
}
@Test(expected = NullPointerException.class)
public void negativeServingTest_2() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.sdModel(sd)
.port(18080)
.build();
}
@Test(expected = IOException.class)
public void negativeServingTest_3() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.sdModel(sd)
.port(18080)
.build();
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
@Override
public PredictedPrice deserialize(String json) {
return null;
}
})
.endpointAddress("http://localhost:18080/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
// warmup
PredictedPrice price = client.predict(house);
server.stop();
}
@Test
public void asyncServingTest() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.sdModel(sd)
.port(18080)
.build();
server.start();
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new PredictedPrice.PredictedPriceDeserializer())
.endpointAddress("http://localhost:18080/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
server.stop();
}
@Test
public void negativeAsyncTest() throws Exception {
val sd = SameDiff.create();
val sdVariable = sd.placeHolder("input", DataType.INT, 4);
val result = sdVariable.add(1.0);
val total = result.mean("total", Integer.MAX_VALUE);
val server = new SameDiffJsonModelServer.Builder<House, PredictedPrice>()
.outputSerializer(new PredictedPrice.PredictedPriceSerializer())
.inputDeserializer(new House.HouseDeserializer())
.inferenceAdapter(new HouseToPredictedPriceAdapter())
.orderedInputNodes(new String[]{"input"})
.orderedOutputNodes(new String[]{"total"})
.sdModel(sd)
.port(18080)
.build();
server.start();
// Fake deserializer to test failure
val client = JsonRemoteInference.<House, PredictedPrice>builder()
.inputSerializer(new House.HouseSerializer())
.outputDeserializer(new JsonDeserializer<PredictedPrice>() {
@Override
public PredictedPrice deserialize(String json) {
return null;
}
})
.endpointAddress("http://localhost:18080/v1/serving")
.build();
int district = 2;
House house = House.builder().area(100).bathrooms(2).bedrooms(3).district(district).build();
val timeStart = System.currentTimeMillis();
try {
Future<PredictedPrice> price = client.predictAsync(house);
assertNotNull(price);
assertEquals((float) district + 1.0f, price.get().getPrice(), 1e-5);
val timeStop = System.currentTimeMillis();
log.info("Time spent: {} ms", timeStop - timeStart);
} catch (ExecutionException e) {
assertTrue(e.getMessage().contains("Deserialization failed"));
}
server.stop();
}
}

View File

@ -0,0 +1,116 @@
package org.nd4j.remote;
import lombok.val;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.HttpClientBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.adapters.InferenceAdapter;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class SameDiffServletTest {
private SameDiffJsonModelServer server;
@Before
public void setUp() throws Exception {
server = new SameDiffJsonModelServer.Builder<String, String>()
.sdModel(SameDiff.create())
.port(8080)
.inferenceAdapter(new InferenceAdapter<String, String>() {
@Override
public MultiDataSet apply(String input) {
return null;
}
@Override
public String apply(INDArray... nnOutput) {
return null;
}
})
.outputSerializer(new JsonSerializer<String>() {
@Override
public String serialize(String o) {
return "";
}
})
.inputDeserializer(new JsonDeserializer<String>() {
@Override
public String deserialize(String json) {
return "";
}
})
.orderedOutputNodes(new String[]{"output"})
.build();
server.start();
//server.join();
}
@After
public void tearDown() throws Exception {
server.stop();
}
@Test
public void getEndpoints() throws IOException {
val request = new HttpGet( "http://localhost:8080/v1" );
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(200, response.getStatusLine().getStatusCode());
}
@Test
public void testContentTypeGet() throws IOException {
val request = new HttpGet( "http://localhost:8080/v1" );
request.setHeader("Content-type", "text/plain");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(415, response.getStatusLine().getStatusCode());
}
@Test
public void testContentTypePost() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving");
request.setHeader("Content-type", "text/plain");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(415, response.getStatusLine().getStatusCode());
}
@Test
public void postForServing() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving");
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(500, response.getStatusLine().getStatusCode());
}
@Test
public void testNotFoundPost() throws Exception {
val request = new HttpPost("http://localhost:8080/v1/serving/some");
request.setHeader("Content-type", "application/json");
val response = HttpClientBuilder.create().build().execute( request );
assertEquals(404, response.getStatusLine().getStatusCode());
}
@Test
public void testNotFoundGet() throws Exception {
val requestGet = new HttpGet( "http://localhost:8080/v1/not_found" );
requestGet.setHeader("Content-type", "application/json");
val responseGet = HttpClientBuilder.create().build().execute( requestGet );
assertEquals(404, responseGet.getStatusLine().getStatusCode());
}
}

View File

@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.helpers;
import com.google.gson.Gson;
import lombok.*;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class House {
private int district;
private int bedrooms;
private int bathrooms;
private int area;
public static class HouseSerializer implements JsonSerializer<House> {
@Override
public String serialize(@NonNull House o) {
return new Gson().toJson(o);
}
}
public static class HouseDeserializer implements JsonDeserializer<House> {
@Override
public House deserialize(@NonNull String json) {
return new Gson().fromJson(json, House.class);
}
}
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.helpers;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.adapters.InferenceAdapter;
@Slf4j
public class HouseToPredictedPriceAdapter implements InferenceAdapter<House, PredictedPrice> {
@Override
public MultiDataSet apply(@NonNull House input) {
// we just create vector array with shape[4] and assign it's value to the district value
return new MultiDataSet(Nd4j.create(DataType.FLOAT, 4).assign(input.getDistrict()), null);
}
@Override
public PredictedPrice apply(INDArray... nnOutput) {
return new PredictedPrice(nnOutput[0].getFloat(0));
}
}

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.remote.helpers;
import com.google.gson.Gson;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PredictedPrice {
private float price;
public static class PredictedPriceSerializer implements JsonSerializer<PredictedPrice> {
@Override
public String serialize(@NonNull PredictedPrice o) {
return new Gson().toJson(o);
}
}
public static class PredictedPriceDeserializer implements JsonDeserializer<PredictedPrice> {
@Override
public PredictedPrice deserialize(@NonNull String json) {
return new Gson().fromJson(json, PredictedPrice.class);
}
}
}

View File

@ -0,0 +1,90 @@
package org.nd4j.remote.serde;
import lombok.val;
import org.junit.Test;
import org.nd4j.remote.clients.serde.impl.*;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class BasicSerdeTests {
private final static DoubleArraySerde doubleArraySerde = new DoubleArraySerde();
private final static FloatArraySerde floatArraySerde = new FloatArraySerde();
private final static StringSerde stringSerde = new StringSerde();
private final static IntegerSerde integerSerde = new IntegerSerde();
private final static FloatSerde floatSerde = new FloatSerde();
private final static DoubleSerde doubleSerde = new DoubleSerde();
private final static BooleanSerde booleanSerde = new BooleanSerde();
@Test
public void testStringSerde_1() {
val jvmString = "String with { strange } elements";
val serialized = stringSerde.serialize(jvmString);
val deserialized = stringSerde.deserialize(serialized);
assertEquals(jvmString, deserialized);
}
@Test
public void testFloatArraySerDe_1() {
val jvmArray = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
val serialized = floatArraySerde.serialize(jvmArray);
val deserialized = floatArraySerde.deserialize(serialized);
assertArrayEquals(jvmArray, deserialized, 1e-5f);
}
@Test
public void testDoubleArraySerDe_1() {
val jvmArray = new double[] {1.0, 2.0, 3.0, 4.0, 5.0};
val serialized = doubleArraySerde.serialize(jvmArray);
val deserialized = doubleArraySerde.deserialize(serialized);
assertArrayEquals(jvmArray, deserialized, 1e-5);
}
@Test
public void testFloatSerde_1() {
val f = 119.f;
val serialized = floatSerde.serialize(f);
val deserialized = floatSerde.deserialize(serialized);
assertEquals(f, deserialized, 1e-5f);
}
@Test
public void testDoubleSerde_1() {
val d = 119.;
val serialized = doubleSerde.serialize(d);
val deserialized = doubleSerde.deserialize(serialized);
assertEquals(d, deserialized, 1e-5);
}
@Test
public void testIntegerSerde_1() {
val f = 119;
val serialized = integerSerde.serialize(f);
val deserialized = integerSerde.deserialize(serialized);
assertEquals(f, deserialized.intValue());
}
@Test
public void testBooleanSerde_1() {
val f = true;
val serialized = booleanSerde.serialize(f);
val deserialized = booleanSerde.deserialize(serialized);
assertEquals(f, deserialized);
}
}

View File

@ -0,0 +1,48 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.eclipse.jetty" level="WARN" />
<logger name="org.apache.catalina.core" level="WARN" />
<logger name="org.springframework" level="WARN" />
<logger name="org.nd4j" level="DEBUG" />
<logger name="org.deeplearning4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

35
nd4j/nd4j-remote/pom.xml Normal file
View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<packaging>pom</packaging>
<modules>
<module>nd4j-json-client</module>
<module>nd4j-grpc-client</module>
<module>nd4j-json-server</module>
</modules>
<parent>
<groupId>org.nd4j</groupId>
<artifactId>nd4j</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>nd4j-remote</artifactId>
<version>1.0.0-SNAPSHOT</version>
<name>nd4j-remote</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<profiles>
<profile>
<id>testresources</id>
</profile>
</profiles>
</project>

View File

@ -33,7 +33,6 @@
<module>nd4j-camel-routes</module> <module>nd4j-camel-routes</module>
<module>nd4j-gson</module> <module>nd4j-gson</module>
<module>nd4j-arrow</module> <module>nd4j-arrow</module>
<module>nd4j-grpc</module>
</modules> </modules>
<profiles> <profiles>

View File

@ -62,6 +62,7 @@
<module>nd4j-parameter-server-parent</module> <module>nd4j-parameter-server-parent</module>
<module>nd4j-uberjar</module> <module>nd4j-uberjar</module>
<module>nd4j-tensorflow</module> <module>nd4j-tensorflow</module>
<module>nd4j-remote</module>
</modules> </modules>
<dependencyManagement> <dependencyManagement>