From dcc2baa6766fb5236a0afae7984c36cddcb80903 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 30 Aug 2019 14:35:27 +1000
Subject: [PATCH] Version upgrades (#199)

* DataVec fixes for Jackson version upgrade
* DL4J jackson updates + databind version 2.9.9.3
* Shade snakeyaml along with jackson
* Version fix
* Switch DataVec legacy JSON format handling to mixins
* Next set of fixes
* Cleanup for legacy JSON mapping
* Upgrade commons compress to 1.18; small test fix
* New Jackson backward compatibility for DL4J - Round 1
* New Jackson backward compatibility for DL4J - Round 2
* More fixes, all but legacy custom passing
* Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format
* Legacy deserialization cleanup
* Small amount of polish - legacy JSON
* Upgrade guava version
* IEvaluation legacy format deserialization fix
* Upgrade play version to 2.7.3
* Update nd4j-parameter-server-status to new Play API
* Update DL4J UI for new play version
* More play framework updates
* Small fixes
* Remove Spark 1/2 adapter code from DataVec
* datavec-spark dependency cleanup
* DL4J spark updates, pt 1
* DL4J spark updates, pt 2
* DL4J spark updates, pt 3
* DL4J spark updates, pt 4
* Test fix
* Another fix
* Breeze upgrade, dependency cleanup
* Add Scala 2.12 version to pom.xml
* change-scala-versions.sh - add scala 2.12, remove 2.10
* Move Spark version properties to parent pom (now that only one spark version is supported)
* DataVec Play fixes
* datavec play dependency fixes
* Clean up old spark/jackson stuff
* Cleanup jackson unused dependencies
* Add shaded guava
* Dropping redundant dependency
* Removed scalaxy dependency
* Ensure not possible to import pre-shaded classes, and remove direct guava dependencies in favor of shaded
* ND4J Shaded guava import fixes
* DataVec and DL4J guava shading
* Arbiter, RL4J fixes
* Build fixed
* Fix dependency
* Fix bad merge
* Jackson shading fixes
* Set play secret, datavec-spark-inference-server
* Fix for datavec-spark-inference-server
* Arbiter fixes
* Small test fix

Signed-off-by: AlexDBlack
Signed-off-by: Alexander Stoyakin
---
 arbiter/arbiter-core/pom.xml | 5 +
 .../distribution/LogUniformDistribution.java | 2 +-
 .../runner/BaseOptimizationRunner.java | 2 +-
 .../runner/LocalOptimizationRunner.java | 6 +-
 .../optimize/serde/jackson/JsonMapper.java | 14 +-
 .../optimize/serde/jackson/YamlMapper.java | 1 +
 .../arbiter/optimize/TestJson.java | 1 +
 arbiter/arbiter-deeplearning4j/pom.xml | 6 +
 .../arbiter/layers/BaseLayerSpace.java | 2 +-
 arbiter/arbiter-ui/pom.xml | 6 -
 .../arbiter/ui/misc/JsonMapper.java | 10 +-
 change-scala-versions.sh | 34 +-
 change-spark-versions.sh | 85 ---
 .../org/datavec/api/transform/Transform.java | 4 +-
 .../api/transform/TransformProcess.java | 11 +
 .../api/transform/analysis/DataAnalysis.java | 12 +
 .../analysis/SequenceDataAnalysis.java | 13 +-
 .../analysis/columns/ColumnAnalysis.java | 4 +-
 .../api/transform/condition/Condition.java | 4 +-
 .../datavec/api/transform/filter/Filter.java | 4 +-
 .../transform/metadata/ColumnMetaData.java | 4 +-
 .../ops/DispatchWithConditionOp.java | 4 +-
 .../transform/rank/CalculateSortedRank.java | 4 +-
 .../datavec/api/transform/schema/Schema.java | 31 +-
 .../sequence/SequenceComparator.java | 4 +-
 .../api/transform/sequence/SequenceSplit.java | 4 +-
 .../sequence/window/WindowFunction.java | 4 +-
 .../api/transform/serde/JsonMappers.java | 239 +-------
 .../legacy/GenericLegacyDeserializer.java | 41 --
 .../serde/legacy/LegacyJsonFormat.java | 267 +++++++++
 .../serde/legacy/LegacyMappingHelper.java | 535 ------------------
 .../stringreduce/IStringReducer.java | 4 +-
 .../api/util/ndarray/RecordConverter.java | 2 +-
 .../datavec/api/writable/ByteWritable.java | 2 +-
 .../datavec/api/writable/DoubleWritable.java | 2 +-
 .../datavec/api/writable/FloatWritable.java | 2 +-
 .../org/datavec/api/writable/IntWritable.java | 2 +-
 .../datavec/api/writable/LongWritable.java | 2 +-
 .../org/datavec/api/writable/Writable.java | 4 +-
 .../writable/batch/NDArrayRecordBatch.java | 2 +-
 .../comparator/WritableComparator.java | 5 +-
 .../datavec/api/split/InputSplitTests.java | 2 +-
 .../api/split/parittion/PartitionerTests.java | 2 +-
 .../api/transform/schema/TestJsonYaml.java | 5 +-
 .../api/writable/RecordConverterTest.java | 2 +-
 datavec/datavec-arrow/pom.xml | 35 --
 .../recordreader/BaseImageRecordReader.java | 2 +-
 .../image/serde/LegacyImageMappingHelper.java | 35 --
 .../image/transform/ImageTransform.java | 6 +-
 datavec/datavec-hadoop/pom.xml | 5 -
 .../reader/TestMapFileRecordReader.java | 2 +-
 .../TestMapFileRecordReaderMultipleParts.java | 2 +-
 ...ileRecordReaderMultiplePartsSomeEmpty.java | 2 +-
 .../writer/TestMapFileRecordWriter.java | 2 +-
 ...JoinFromCoGroupFlatMapFunctionAdapter.java | 2 +-
 .../datavec-spark-inference-server/pom.xml | 58 +-
 .../transform/CSVSparkTransformServer.java | 81 +--
 .../datavec/spark/transform/FunctionUtil.java | 41 --
 .../transform/ImageSparkTransformServer.java | 62 +-
 .../spark/transform/SparkTransformServer.java | 21 +-
 .../src/main/resources/application.conf | 350 ++++++++++++
 datavec/datavec-spark/pom.xml | 169 +-----
 .../functions/FlatMapFunctionAdapter.java | 29 -
 .../datavec/spark/transform/DataFrames.java | 48 +-
 .../spark/transform/Normalization.java | 59 +-
 .../analysis/SequenceFlatMapFunction.java | 10 +-
 .../SequenceFlatMapFunctionAdapter.java | 36 --
 ...ExecuteJoinFromCoGroupFlatMapFunction.java | 90 ++-
 ...JoinFromCoGroupFlatMapFunctionAdapter.java | 119 ----
 .../join/FilterAndFlattenJoinedValues.java | 41 +-
 .../FilterAndFlattenJoinedValuesAdapter.java | 71 ---
 .../sparkfunction/SequenceToRows.java | 56 +-
 .../sparkfunction/SequenceToRowsAdapter.java | 87 ---
 .../transform/SequenceSplitFunction.java | 14 +-
 .../SequenceSplitFunctionAdapter.java | 41 --
 .../SparkTransformProcessFunction.java | 19 +-
 .../SparkTransformProcessFunctionAdapter.java | 45 --
 .../transform/BaseFlatMapFunctionAdaptee.java | 41 --
 .../spark/transform/DataRowsFacade.java | 42 --
 .../transform/BaseFlatMapFunctionAdaptee.java | 42 --
 .../spark/transform/DataRowsFacade.java | 43 --
 .../spark/storage/TestSparkStorageUtils.java | 2 +-
 .../spark/transform/DataFramesTests.java | 16 +-
 .../spark/transform/NormalizationTests.java | 34 +-
 .../config/DL4JSystemProperties.java | 12 -
 deeplearning4j/deeplearning4j-core/pom.xml | 12 -
 .../RecordReaderDataSetiteratorTest.java | 2 +-
 .../RecordReaderMultiDataSetIteratorTest.java | 2 +-
 .../deeplearning4j/nn/dtypes/DTypeTests.java | 6 +-
 .../nn/layers/recurrent/TestRnnLayers.java | 4 +-
 .../plot/BarnesHutTsneTest.java | 2 +-
 .../regressiontest/RegressionTest100a.java | 79 +--
 .../iterator/MultipleEpochsIterator.java | 4 +-
 .../FileSplitParallelDataSetIterator.java | 2 +-
 .../deeplearning4j/plot/BarnesHutTsne.java | 2 +-
 .../java/org/deeplearning4j/plot/Tsne.java | 2 +-
 .../deeplearning4j-modelimport/pom.xml | 6 +
 .../pom.xml | 20 -
 .../nearestneighbor/server/FunctionUtil.java | 41 --
 .../server/NearestNeighborsServer.java | 53 +-
 .../nearestneighbor-core/pom.xml | 7 +
 .../clustering/info/ClusterSetInfo.java | 4 +-
 .../clustering/quadtree/QuadTree.java | 2 +-
 .../clustering/randomprojection/RPUtils.java | 2 +-
 .../clustering/sptree/SpTree.java | 2 +-
 .../clustering/kdtree/KDTreeTest.java | 2 +-
 .../clustering/sptree/SPTreeTest.java | 2 +-
 .../clustering/vptree/VpTreeNodeTest.java | 3 -
 .../text/corpora/sentiwordnet/SWN3.java | 2 +-
 .../models/WordVectorSerializerTest.java | 4 +-
 .../models/word2vec/Word2VecTests.java | 4 +-
 .../inmemory/InMemoryLookupTable.java | 2 +-
 .../reader/impl/BasicModelUtils.java | 2 +-
 .../wordvectors/WordVectorsImpl.java | 2 +-
 .../models/glove/count/CountMap.java | 2 +-
 .../paragraphvectors/ParagraphVectors.java | 2 +-
 .../sequencevectors/SequenceVectors.java | 4 +-
 .../sequence/SequenceElement.java | 2 +-
 .../text/invertedindex/InvertedIndex.java | 2 +-
 .../wordvectors/WordVectorsImplTest.java | 2 +-
 .../deeplearning4j/eval/ConfusionMatrix.java | 4 +-
 .../eval/curves/PrecisionRecallCurve.java | 2 +-
 .../deeplearning4j/eval/curves/RocCurve.java | 2 +-
 .../conf/ComputationGraphConfiguration.java | 22 +
 .../nn/conf/InputPreProcessor.java | 4 +-
 .../nn/conf/MultiLayerConfiguration.java | 22 +
 .../nn/conf/NeuralNetConfiguration.java | 94 +--
 .../nn/conf/graph/AttentionVertex.java | 2 +-
 .../nn/conf/graph/GraphVertex.java | 4 +-
 .../nn/conf/inputs/InputType.java | 58 +-
 .../deeplearning4j/nn/conf/layers/Layer.java | 4 +-
 .../nn/conf/layers/Upsampling2D.java | 2 +-
 .../nn/conf/layers/misc/FrozenLayer.java | 2 -
 .../ReconstructionDistribution.java | 4 +-
 .../conf/serde/BaseNetConfigDeserializer.java | 82 ++-
 ...utationGraphConfigurationDeserializer.java | 11 +
 .../conf/serde/FrozenLayerDeserializer.java | 58 --
 .../nn/conf/serde/JsonMappers.java | 148 +----
 .../MultiLayerConfigurationDeserializer.java | 55 +-
 .../LegacyIntArrayDeserializer.java | 2 +-
 .../conf/serde/legacy/LegacyJsonFormat.java | 175 ++++++
 .../LegacyGraphVertexDeserializer.java | 94 ---
 .../LegacyGraphVertexDeserializerHelper.java | 28 -
 .../legacyformat/LegacyLayerDeserializer.java | 113 ----
 .../LegacyLayerDeserializerHelper.java | 28 -
 .../LegacyPreprocessorDeserializer.java | 83 ---
 .../LegacyPreprocessorDeserializerHelper.java | 28 -
 ...econstructionDistributionDeserializer.java | 70 ---
 ...ructionDistributionDeserializerHelper.java | 28 -
 .../nn/workspace/LayerWorkspaceMgr.java | 2 +-
 .../listeners/CheckpointListener.java | 2 +-
 .../listeners/PerformanceListener.java | 2 +-
 .../solvers/accumulation/EncodingHandler.java | 2 +-
 .../deeplearning4j/util/ModelSerializer.java | 2 +-
 .../deeplearning4j-aws/pom.xml | 27 -
 .../pom.xml | 16 -
 .../EarlyStoppingParallelTrainer.java | 2 +-
 .../observers/BasicInferenceObservable.java | 2 +-
 .../spark/dl4j-spark-nlp-java8/pom.xml | 36 --
 .../functions/VocabRddFunctionFlat.java | 91 ++-
 .../word2vec/FirstIterationFunction.java | 240 +++++++-
 .../FirstIterationFunctionAdapter.java | 265 ---------
 .../word2vec/SecondIterationFunction.java | 26 +-
 .../spark/dl4j-spark-parameterserver/pom.xml | 36 --
 .../functions/SharedFlatMapDataSet.java | 21 +-
 .../functions/SharedFlatMapMultiDataSet.java | 23 +-
 .../functions/SharedFlatMapPaths.java | 20 +-
 .../functions/SharedFlatMapPathsMDS.java | 21 +-
 .../api/worker/ExecuteWorkerFlatMap.java | 37 +-
 .../ExecuteWorkerMultiDataSetFlatMap.java | 38 +-
 .../api/worker/ExecuteWorkerPDSFlatMap.java | 28 +-
 .../worker/ExecuteWorkerPDSMDSFlatMap.java | 28 +-
 .../api/worker/ExecuteWorkerPathFlatMap.java | 28 +-
 .../worker/ExecuteWorkerPathMDSFlatMap.java | 30 +-
 .../spark/data/BatchDataSetsFunction.java | 37 +-
 .../spark/data/SplitDataSetsFunction.java | 30 +-
 ...litDataSetExamplesPairFlatMapFunction.java | 30 +-
 .../spark/datavec/RDDMiniBatches.java | 27 +-
 .../HashingBalancedPartitioner.java | 21 +-
 .../repartition/MapTupleToPairFlatMap.java | 17 +-
 ...VaeReconstructionProbWithKeyFunction.java} | 6 +-
 ....java => BaseVaeScoreWithKeyFunction.java} | 14 +-
 .../IEvaluateMDSFlatMapFunction.java | 37 +-
 .../IEvaluateMDSPathsFlatMapFunction.java | 34 +-
 ...VaeReconstructionErrorWithKeyFunction.java | 4 +-
 ...GVaeReconstructionProbWithKeyFunction.java | 4 +-
 .../GraphFeedForwardWithKeyFunction.java | 49 +-
 .../graph/scoring/ScoreExamplesFunction.java | 33 +-
 .../scoring/ScoreExamplesWithKeyFunction.java | 39 +-
 .../ScoreFlatMapFunctionCGDataSet.java | 25 +-
 .../ScoreFlatMapFunctionCGMultiDataSet.java | 25 +-
 .../evaluation/IEvaluateFlatMapFunction.java | 33 +-
 .../scoring/FeedForwardWithKeyFunction.java | 32 +-
 .../scoring/ScoreExamplesFunction.java | 30 +-
 .../scoring/ScoreExamplesWithKeyFunction.java | 42 +-
 .../scoring/ScoreFlatMapFunction.java | 36 +-
 ...VaeReconstructionErrorWithKeyFunction.java | 26 +-
 .../VaeReconstructionProbWithKeyFunction.java | 28 +-
 .../ParameterAveragingTrainingMaster.java | 2 +-
 .../BaseDoubleFlatMapFunctionAdaptee.java | 40 --
 .../util/BasePairFlatMapFunctionAdaptee.java | 41 --
 .../BaseDoubleFlatMapFunctionAdaptee.java | 42 --
 .../util/BasePairFlatMapFunctionAdaptee.java | 43 --
 .../deeplearning4j-scaleout/spark/pom.xml | 52 --
 .../deeplearning4j-play/pom.xml | 85 ---
 .../deeplearning4j/ui/api/FunctionType.java | 11 +-
 .../java/org/deeplearning4j/ui/api/Route.java | 17 +-
 .../ui/module/tsne/TsneModule.java | 29 +-
 .../deeplearning4j/ui/play/PlayUIServer.java | 143 ++---
 .../ui/play/misc/FunctionUtil.java | 45 --
 .../ui/play/staticroutes/Assets.java | 26 +-
 .../IntegrationTestBaselineGenerator.java | 2 +-
 .../integration/IntegrationTestRunner.java | 4 +-
 .../integration/testcases/RNNTestCases.java | 2 +-
 deeplearning4j/pom.xml | 5 -
 .../autodiff/listeners/ListenerVariables.java | 2 +-
 .../checkpoint/CheckpointListener.java | 2 +-
 .../listeners/records/EvaluationRecord.java | 8 +-
 .../org/nd4j/autodiff/samediff/SameDiff.java | 10 +-
 .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 2 +-
 .../samediff/serde/FlatBuffersMapper.java | 2 +-
 .../autodiff/validation/OpValidation.java | 6 +-
 .../org/nd4j/evaluation/BaseEvaluation.java | 12 +-
 .../classification/ConfusionMatrix.java | 4 +-
 .../evaluation/custom/CustomEvaluation.java | 2 +-
 .../nd4j/evaluation/custom/MergeLambda.java | 2 +-
 .../serde/ConfusionMatrixSerializer.java | 2 +-
 .../graphmapper/onnx/OnnxGraphMapper.java | 6 +-
 .../imports/graphmapper/tf/TFGraphMapper.java | 4 +-
 .../nd4j/linalg/api/ndarray/BaseNDArray.java | 4 +-
 .../linalg/api/ndarray/BaseSparseNDArray.java | 4 +-
 .../api/ndarray/BaseSparseNDArrayCOO.java | 6 +-
 .../api/ndarray/BaseSparseNDArrayCSR.java | 2 +-
 .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 2 +-
 .../nd4j/linalg/api/ops/DynamicCustomOp.java | 6 +-
 .../nd4j/linalg/api/ops/aggregates/Batch.java | 2 +-
 .../ops/impl/controlflow/compat/Switch.java | 2 +-
 .../api/ops/impl/reduce/TensorMmul.java | 4 +-
 .../linalg/api/ops/impl/shape/Transpose.java | 2 +-
 .../ops/impl/transforms/custom/Choose.java | 4 +-
 .../java/org/nd4j/linalg/api/shape/Shape.java | 6 +-
 .../linalg/dataset/BalanceMinibatches.java | 4 +-
 .../java/org/nd4j/linalg/dataset/DataSet.java | 4 +-
 .../org/nd4j/linalg/dataset/api/DataSet.java | 2 +-
 .../linalg/factory/BaseNDArrayFactory.java | 2 +-
 .../java/org/nd4j/linalg/factory/Nd4j.java | 4 +-
 .../org/nd4j/linalg/indexing/Indices.java | 4 +-
 .../nd4j/linalg/indexing/IntervalIndex.java | 2 +-
 .../nd4j/linalg/indexing/NDArrayIndex.java | 4 +-
 .../org/nd4j/linalg/indexing/PointIndex.java | 2 +-
 .../nd4j/linalg/indexing/SpecifiedIndex.java | 2 +-
 .../linalg/indexing/conditions/Condition.java | 2 +-
 .../linalg/indexing/functions/Identity.java | 2 +-
 .../indexing/functions/StableNumber.java | 2 +-
 .../nd4j/linalg/indexing/functions/Value.java | 2 +-
 .../nd4j/linalg/indexing/functions/Zero.java | 2 +-
 .../org/nd4j/jita/handler/MemoryHandler.java | 2 +-
 .../jita/handler/impl/CudaZeroHandler.java | 4 +-
 .../nd4j/linalg/jcublas/util/CudaArgs.java | 4 +-
 .../linalg/api/indexing/IndexingTestsC.java | 1 -
 nd4j/nd4j-common/pom.xml | 11 +-
 .../linalg/collection/IntArrayKeyMap.java | 2 +-
 .../nd4j/linalg/primitives/AtomicDouble.java | 4 +-
 .../java/org/nd4j/linalg/util/ArrayUtil.java | 4 +-
 .../nd4j/linalg/util/SynchronizedTable.java | 2 +-
 .../nd4j/resources/strumpf/ResourceFile.java | 2 +-
 .../transport/RoutedTransport.java | 2 +-
 .../v2/transport/impl/AeronUdpTransport.java | 2 +-
 .../nd4j-parameter-server-status/pom.xml | 10 -
 .../status/play/StatusServer.java | 102 +---
 .../ParameterServerSubscriber.java | 2 +-
 .../ipc/chunk/InMemoryChunkAccumulator.java | 2 +-
 nd4j/nd4j-serde/nd4j-arrow/pom.xml | 30 -
 nd4j/nd4j-serde/nd4j-gson/pom.xml | 3 -
 .../serde/gson/GsonDeserializationUtils.java | 4 +-
 nd4j/nd4j-shade/guava/pom.xml | 219 +++++++
 nd4j/nd4j-shade/jackson/pom.xml | 131 +++--
 nd4j/nd4j-shade/pom.xml | 1 +
 nd4j/nd4j-shade/protobuf/pom.xml | 21 +
 nd4j/pom.xml | 5 -
 nd4s/build.sbt | 2 +-
 nd4s/pom.xml | 14 +-
 .../org/nd4s/CollectionLikeNDArray.scala | 5 +-
 pom.xml | 31 +-
 rl4j/rl4j-core/pom.xml | 6 +
 scalnet/pom.xml | 6 +
 286 files changed, 2774 insertions(+), 5074 deletions(-)
 delete mode 100755 change-spark-versions.sh
 delete mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java
 create mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java
 delete mode 100644 datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java
 delete mode 100644 datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java
 delete mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java
 create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java
 delete mode 100644 datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
 delete mode 100644 datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java
 delete mode 100644 datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
 delete mode 100644 datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java
 delete mode 100644 deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java
 rename deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/{legacyformat => legacy}/LegacyIntArrayDeserializer.java (97%)
 create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java
 delete mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java
 delete mode 100644 deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java
 rename deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/{BaseVaeReconstructionProbWithKeyFunctionAdapter.java => BaseVaeReconstructionProbWithKeyFunction.java} (87%)
 rename deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/{BaseVaeScoreWithKeyFunctionAdapter.java => BaseVaeScoreWithKeyFunction.java} (88%)
 delete mode 100644 deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java
 delete mode 100644 deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java
 delete mode 100644 deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java
 delete mode 100644 deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java
 delete mode 100644 deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java
 create mode 100644 nd4j/nd4j-shade/guava/pom.xml

diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml
index 296af48fd..064dd3ecd 100644
--- a/arbiter/arbiter-core/pom.xml
+++ b/arbiter/arbiter-core/pom.xml
@@ -72,6 +72,11 @@
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>joda-time</groupId>
+            <artifactId>joda-time</artifactId>
+            <version>${jodatime.version}</version>
+        </dependency>
     </dependencies>
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
index 9e50065e6..a9c0933c4 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.distribution;

-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.Getter;
 import org.apache.commons.math3.distribution.RealDistribution;
 import org.apache.commons.math3.exception.NumberIsTooLargeException;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
index fa503ef6d..0e04a130b 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.runner;

-import com.google.common.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
 import lombok.AllArgsConstructor;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
index a3992b09a..6982090f1 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
@@ -16,9 +16,9 @@

 package org.deeplearning4j.arbiter.optimize.runner;

-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
+import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
 import lombok.Setter;
 import org.deeplearning4j.arbiter.optimize.api.*;
 import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
index 9e30a06f6..8cfb07723 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
@@ -43,13 +43,15 @@ public class JsonMapper {
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
         mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
         yamlMapper = new ObjectMapper(new YAMLFactory());
-        mapper.registerModule(new JodaModule());
-        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        mapper.enable(SerializationFeature.INDENT_OUTPUT);
-        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
-        mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        yamlMapper.registerModule(new JodaModule());
+        yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+        yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+        yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
+        yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+        yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
     }

     private JsonMapper() {}
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
index f10c593e0..5b35220e9 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
@@ -39,6 +39,7 @@ public class YamlMapper {
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
         mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
     }
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
index 3572b187b..6f1b336bb 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
@@ -59,6 +59,7 @@ public class TestJson {
         om.enable(SerializationFeature.INDENT_OUTPUT);
         om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
         return om;
     }
diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml
index 77f7e34a9..85afe7a6b 100644
--- a/arbiter/arbiter-deeplearning4j/pom.xml
+++ b/arbiter/arbiter-deeplearning4j/pom.xml
@@ -57,6 +57,12 @@
             <artifactId>jackson</artifactId>
             <version>${nd4j.version}</version>
         </dependency>
+
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
     </dependencies>
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
index 77c31707a..0a5e33d27 100644
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
+++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.layers;

-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.AccessLevel;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml
index 93e955219..56d1013bf 100644
--- a/arbiter/arbiter-ui/pom.xml
+++ b/arbiter/arbiter-ui/pom.xml
@@ -107,12 +107,6 @@
-        <dependency>
-            <groupId>com.google.guava</groupId>
-            <artifactId>guava</artifactId>
-            <version>${guava.version}</version>
-        </dependency>
-
         <dependency>
             <groupId>org.deeplearning4j</groupId>
             <artifactId>arbiter-core</artifactId>
diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
index 491756ae7..0ac5ad383 100644
--- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
+++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
@@ -18,6 +18,7 @@
 package org.deeplearning4j.arbiter.ui.misc;

 import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
+import org.nd4j.shade.jackson.annotation.PropertyAccessor;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.MapperFeature;
@@ -45,12 +46,9 @@ public class JsonMapper {
         mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
         mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
-
-        mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker()
-                .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
-                .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
-                .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
-                .withCreatorVisibility(JsonAutoDetect.Visibility.NONE));
+        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+        mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
         return mapper;
     }
diff --git a/change-scala-versions.sh b/change-scala-versions.sh
index 5fb40c5cb..8968abbf3 100755
--- a/change-scala-versions.sh
+++ b/change-scala-versions.sh
@@ -20,9 +20,9 @@

 set -e

-VALID_VERSIONS=( 2.10 2.11 )
-SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}";
+VALID_VERSIONS=( 2.11 2.12 )
 SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}";
+SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}";

 usage() {
   echo "Usage: $(basename $0) [-h|--help]
@@ -45,19 +45,18 @@ check_scala_version() {
   exit 1
 }

-
 check_scala_version "$TO_VERSION"

 if [ $TO_VERSION = "2.11" ]; then
-  FROM_BINARY="_2\.10"
+  FROM_BINARY="_2\.12"
   TO_BINARY="_2\.11"
-  FROM_VERSION=$SCALA_210_VERSION
+  FROM_VERSION=$SCALA_212_VERSION
   TO_VERSION=$SCALA_211_VERSION
 else
   FROM_BINARY="_2\.11"
-  TO_BINARY="_2\.10"
+  TO_BINARY="_2\.12"
   FROM_VERSION=$SCALA_211_VERSION
-  TO_VERSION=$SCALA_210_VERSION
+  TO_VERSION=$SCALA_212_VERSION
 fi

 sed_i() {
@@ -70,35 +69,24 @@

 echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION to $TO_VERSION";

 BASEDIR=$(dirname $0)

-#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc.
+#Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc.
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \;

-#Scala versions, like 2.10
+#Scala versions, like 2.11
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \;

-#Scala binary versions, like 2.10
+#Scala binary versions, like 2.11
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \;

-#Scala versions, like scala-library 2.10.6
+#Scala versions, like scala-library 2.11.12
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \;

-#Scala maven plugin, 2.10
+#Scala maven plugin, 2.11
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \;
-
-#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306
-#https://github.com/deeplearning4j/deeplearning4j/issues/6306
-if [[ $TO_VERSION == 2.11* ]]; then
-  sed_i 's/korean-text-scala-2.10<\/artifactId>/korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-  sed_i 's/4.2.0<\/version>/4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-else
-  sed_i 's/korean-text<\/artifactId>/korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-  sed_i 's/4.4<\/version>/4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-fi
-
 echo "Done updating Scala versions.";
diff --git a/change-spark-versions.sh b/change-spark-versions.sh
deleted file mode 100755
index 06a9b4d55..000000000
--- a/change-spark-versions.sh
+++ /dev/null
@@ -1,85 +0,0 @@
-#!/usr/bin/env bash
-
-################################################################################
-# Copyright (c) 2015-2018 Skymind, Inc.
-#
-# This program and the accompanying materials are made available under the
-# terms of the Apache License, Version 2.0 which is available at
-# https://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications.
-
-set -e
-
-VALID_VERSIONS=( 1 2 )
-SPARK_2_VERSION="2\.1\.0"
-SPARK_1_VERSION="1\.6\.3"
-
-usage() {
-  echo "Usage: $(basename $0) [-h|--help]
-where :
-  -h| --help Display this help text
-  valid spark version values : ${VALID_VERSIONS[*]}
-" 1>&2
-  exit 1
-}
-
-if [[ ($# -ne 1) || ( $1 == "--help") ||  $1 == "-h" ]]; then
-  usage
-fi
-
-TO_VERSION=$1
-
-check_spark_version() {
-  for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
-  echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
-  exit 1
-}
-
-
-check_spark_version "$TO_VERSION"
-
-if [ $TO_VERSION = "2" ]; then
-  FROM_BINARY="1"
-  TO_BINARY="2"
-  FROM_VERSION=$SPARK_1_VERSION
-  TO_VERSION=$SPARK_2_VERSION
-else
-  FROM_BINARY="2"
-  TO_BINARY="1"
-  FROM_VERSION=$SPARK_2_VERSION
-  TO_VERSION=$SPARK_1_VERSION
-fi
-
-sed_i() {
-  sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
-}
-
-export -f sed_i
-
-echo "Updating Spark versions in pom.xml files to Spark $1";
-
-BASEDIR=$(dirname $0)
-
-# 1
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \;
-
-# 1.6.3
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \;
-
-#Spark versions, like xxx_spark_2xxx OR xxx_spark_2xxx
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \;
-
-echo "Done updating Spark versions.";
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
index 44098e6cc..32f8f7cc5 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
@@ -16,7 +16,6 @@

 package org.datavec.api.transform;

-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -27,8 +26,7 @@ import java.util.List;

 /**A Transform converts an example to another example, or a sequence to another sequence
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.TransformHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Transform extends Serializable, ColumnOp {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
index 24a029179..7c57f4daa 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
@@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

 import java.io.IOException;
 import java.io.Serializable;
@@ -417,6 +418,16 @@ public class TransformProcess implements Serializable {
     public static TransformProcess fromJson(String json) {
         try {
             return JsonMappers.getMapper().readValue(json, TransformProcess.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                try{
+                    return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (IOException e) {
             //TODO proper exception message
             throw new RuntimeException(e);
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
index 6b069a9ec..467db70f0 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
@@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
 import org.datavec.api.transform.metadata.CategoricalMetaData;
 import org.datavec.api.transform.metadata.ColumnMetaData;
 import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
 import org.datavec.api.transform.serde.JsonSerializer;
 import org.datavec.api.transform.serde.YamlSerializer;
 import org.nd4j.shade.jackson.annotation.JsonSubTypes;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.databind.JsonNode;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
 import org.nd4j.shade.jackson.databind.node.ArrayNode;

 import java.io.IOException;
@@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable {
     public static DataAnalysis fromJson(String json) {
         try{
             return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (Exception e){
             //Legacy format
             ObjectMapper om = new JsonSerializer().getObjectMapper();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
index ecc333d2d..6156ead40 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
@@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode;
 import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
 import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
 import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
 import org.datavec.api.transform.serde.JsonSerializer;
 import org.datavec.api.transform.serde.YamlSerializer;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

 import java.io.IOException;
 import java.util.List;
@@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis {
     public static SequenceDataAnalysis fromJson(String json){
         try{
             return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (IOException e){
             throw new RuntimeException(e);
         }
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
index dd43315d2..c86584ede 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.analysis.columns;

 import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -27,8 +26,7 @@ import java.io.Serializable;
  * Interface for column analysis
  */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface ColumnAnalysis extends Serializable {

     long getCountTotal();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
index 83a96bfcf..6bd5b98ac 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
@@ -18,7 +18,6 @@ package org.datavec.api.transform.condition;

 import org.datavec.api.transform.ColumnOp;
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -35,8 +34,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ConditionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Condition extends Serializable, ColumnOp {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
index 5aa672f76..16870e9f9 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
@@ -18,7 +18,6 @@ package org.datavec.api.transform.filter;

 import org.datavec.api.transform.ColumnOp;
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -33,8 +32,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.FilterHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Filter extends Serializable, ColumnOp {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
index cde86cf90..6889fbf31 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.metadata;

 import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.io.Serializable;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface ColumnMetaData extends Serializable, Cloneable {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
index 3341e0b6d..8ef67bacb 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
@@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable;

 import java.util.List;

-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
+import static org.nd4j.shade.guava.base.Preconditions.checkArgument;
+import static org.nd4j.shade.guava.base.Preconditions.checkNotNull;

 /**
  * A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition},
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
index 5bcf81f7a..1e5177c68 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
@@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData;
 import org.datavec.api.transform.metadata.LongMetaData;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.transform.schema.SequenceSchema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.comparator.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
@@ -50,8 +49,7 @@ import java.util.List;
 @EqualsAndHashCode(exclude = {"inputSchema"})
 @JsonIgnoreProperties({"inputSchema"})
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public class CalculateSortedRank implements Serializable, ColumnOp {

     private final String newColumnName;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
index a9e167943..1c16ebcce 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
@@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
 import org.datavec.api.transform.ColumnType;
 import org.datavec.api.transform.metadata.*;
 import org.datavec.api.transform.serde.JsonMappers;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.*;
 import org.joda.time.DateTimeZone;
 import org.nd4j.shade.jackson.annotation.*;
@@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 import org.nd4j.shade.jackson.databind.SerializationFeature;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
 import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
 import org.nd4j.shade.jackson.datatype.joda.JodaModule;

+import java.io.IOException;
 import java.io.Serializable;
 import java.util.*;

@@ -48,8 +49,7 @@ import java.util.*;
 */
 @JsonIgnoreProperties({"columnNames", "columnNamesIndex"})
 @EqualsAndHashCode
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SchemaHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 @Data
 public class Schema implements Serializable {

@@ -358,6 +358,16 @@ public class Schema implements Serializable {
     public static Schema fromJson(String json) {
         try{
             return JsonMappers.getMapper().readValue(json, Schema.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, Schema.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (Exception e){
             //TODO better exceptions
             throw new RuntimeException(e);
@@ -379,21 +389,6 @@ public class Schema implements Serializable {
         }
     }

-    private static Schema fromJacksonString(String str, JsonFactory factory) {
-        ObjectMapper om = new ObjectMapper(factory);
-        om.registerModule(new JodaModule());
-        om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        om.enable(SerializationFeature.INDENT_OUTPUT);
-        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
-        om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
-        try {
-            return om.readValue(str, Schema.class);
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-    }
-
     public static class Builder {
         List<ColumnMetaData> columnMetaData = new ArrayList<>();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
index e4a09f6e9..a5677d1f5 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -30,8 +29,7 @@ import java.util.List;
 * Compare the time steps of a sequence
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface SequenceComparator extends Comparator<List<Writable>>, Serializable {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
index bef0b4ccc..3471dbaa3 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface SequenceSplit extends Serializable {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
index e25f85253..1456af8d7 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence.window;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -36,8 +35,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface WindowFunction extends Serializable {

     /**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
index 652556ce4..bfa114697 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
@@ -16,44 +16,17 @@

 package org.datavec.api.transform.serde;

-import lombok.AllArgsConstructor;
-import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
-import org.datavec.api.io.WritableComparator;
-import org.datavec.api.transform.Transform;
-import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
-import org.datavec.api.transform.condition.column.ColumnCondition;
-import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.transform.metadata.ColumnMetaData;
-import org.datavec.api.transform.rank.CalculateSortedRank;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.sequence.SequenceComparator;
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.transform.sequence.window.WindowFunction;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
-import org.datavec.api.writable.Writable;
-import org.nd4j.linalg.activations.IActivation;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
-import org.nd4j.linalg.primitives.Pair;
-import org.nd4j.serde.json.LegacyIActivationDeserializer;
-import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
+import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;
 import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
-import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.annotation.PropertyAccessor;
-import org.nd4j.shade.jackson.databind.*;
-import org.nd4j.shade.jackson.databind.cfg.MapperConfig;
-import org.nd4j.shade.jackson.databind.introspect.Annotated;
-import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
-import org.nd4j.shade.jackson.databind.introspect.AnnotationMap;
-import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector;
-import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder;
+import org.nd4j.shade.jackson.databind.DeserializationFeature;
+import org.nd4j.shade.jackson.databind.MapperFeature;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.SerializationFeature;
 import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
 import org.nd4j.shade.jackson.datatype.joda.JodaModule;

-import java.lang.annotation.Annotation;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-
 /**
  * JSON mappers for deserializing neural net configurations, etc.
  *
@@ -62,38 +35,9 @@
 @Slf4j
 public class JsonMappers {

-    /**
-     * This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])}
-     * Classes can be specified in comma-separated format
-     */
-    public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses";
-
-    static {
-        String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY);
-        if(p != null && !p.isEmpty()){
-            String[] split = p.split(",");
-            List<Class<?>> list = new ArrayList<>();
-            for(String s : split){
-                try{
-                    Class<?> c = Class.forName(s);
-                    list.add(c);
-                } catch (Throwable t){
-                    log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t);
-                }
-            }
-
-            if(list.size() > 0){
-                try {
-                    registerLegacyCustomClassesForJSONList(list);
-                } catch (Throwable t){
-                    log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t);
-                }
-            }
-        }
-    }
-
     private static ObjectMapper jsonMapper;
     private static ObjectMapper yamlMapper;
+    private static ObjectMapper legacyMapper;   //For 1.0.0-alpha and earlier TransformProcess etc

     static {
         jsonMapper = new ObjectMapper();
@@ -102,117 +46,12 @@ public class JsonMappers {
         configureMapper(yamlMapper);
     }

-    private static Map legacyMappers = new ConcurrentHashMap<>();
-
-
-    /**
-     * Register a set of classes (Transform, Filter, etc) for JSON deserialization.
- *
- * This is required ONLY when BOTH of the following conditions are met:
- * 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND
- * 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)
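Before this patch, loading such a custom type from pre-1.0.0-beta JSON required an explicit registration call, or the org.datavec.config.custom.legacyclasses system property parsed by the static block above. A sketch of the now-removed call (MyCustomTransform and oldFormatJson are hypothetical):

    // Equivalent to -Dorg.datavec.config.custom.legacyclasses=com.example.MyCustomTransform
    JsonMappers.registerLegacyCustomClassesForJSON(MyCustomTransform.class);
    // oldFormatJson: a TransformProcess saved by 1.0.0-alpha or earlier
    TransformProcess tp = TransformProcess.fromJson(oldFormatJson);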
- *
- * By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON - * format change between versions. - * - * @param classes Classes to register - */ - public static void registerLegacyCustomClassesForJSON(Class... classes) { - registerLegacyCustomClassesForJSONList(Arrays.>asList(classes)); - } - - /** - * @see #registerLegacyCustomClassesForJSON(Class[]) - */ - public static void registerLegacyCustomClassesForJSONList(List> classes){ - //Default names (i.e., old format for custom JSON format) - List> list = new ArrayList<>(); - for(Class c : classes){ - list.add(new Pair(c.getSimpleName(), c)); + public static synchronized ObjectMapper getLegacyMapper(){ + if(legacyMapper == null){ + legacyMapper = LegacyJsonFormat.legacyMapper(); + configureMapper(legacyMapper); } - registerLegacyCustomClassesForJSON(list); - } - - /** - * Set of classes that can be registered for legacy deserialization. - */ - private static List> REGISTERABLE_CUSTOM_CLASSES = (List>) Arrays.>asList( - Transform.class, - ColumnAnalysis.class, - ColumnCondition.class, - Filter.class, - ColumnMetaData.class, - CalculateSortedRank.class, - Schema.class, - SequenceComparator.class, - SequenceSplit.class, - WindowFunction.class, - Writable.class, - WritableComparator.class - ); - - /** - * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution - * ONLY) for JSON deserialization, with custom names.
- * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}) - * but is added in case it is required in non-standard circumstances. - */ - public static void registerLegacyCustomClassesForJSON(List<Pair<String, Class>> classes){ - for(Pair<String, Class> p : classes){ - String s = p.getFirst(); - Class c = p.getRight(); - //Check if it's a valid class to register... - boolean found = false; - for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){ - if(c2.isAssignableFrom(c)){ - Map<String, String> map = LegacyMappingHelper.legacyMappingForClass(c2); - map.put(p.getFirst(), p.getSecond().getName()); - found = true; - } - } - - if(!found){ - throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " + - c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES); - } - } - } - - - /** - * Get the legacy JSON mapper for the specified class.
- * - * NOTE: This is intended for internal backward-compatibility use. - * - * Note to developers: The following JSON mappers are for handling legacy format JSON. - * Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc. was changed from - * a wrapper object to an "@class" field. However, to avoid breaking all previously saved transforms, these mappers are - * part of the solution.
- *
- * How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)
- * 1. Transform etc. JSON that has a "@class" field is deserialized as normal
- * 2. Transform JSON that doesn't have such a field is mapped (via the @JsonTypeInfo defaultImpl setting) to LegacyMappingHelper.TransformHelper
- * 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it
- * 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names - * 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization - * - * Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format, - * as it'll fail due to not having the expected "@class" annotation. - * Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified - * class anyway. The ignoring is done via an annotation introspector, defined below in this class. - * However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of - * all types - if we did, then any nested types would fail (i.e., a Condition in a Transform - the Transform couldn't - * be deserialized correctly, as the annotation would be ignored). - * - */ - public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class clazz){ - if(!legacyMappers.containsKey(clazz)){ - ObjectMapper m = new ObjectMapper(); - configureMapper(m); - m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(clazz))); - legacyMappers.put(clazz, m); - } - return legacyMappers.get(clazz); + return legacyMapper; } /** @@ -237,61 +76,7 @@ ret.enable(SerializationFeature.INDENT_OUTPUT); ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); //Need this otherwise JsonProperty annotations on constructors won't be seen } - - /** - * Custom Jackson Introspector to ignore the {@code @JsonTypeInfo} annotations on layers etc.
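On the CREATOR visibility line added to configureMapper() above: because the method first sets PropertyAccessor.ALL to Visibility.NONE, constructor creators must be re-enabled explicitly, per the in-line comment. A minimal illustration (the class is hypothetical; note the shaded Jackson import):

    import org.nd4j.shade.jackson.annotation.JsonProperty;

    public class MyMeta {
        private final double min;

        // Per the comment above: without CREATOR visibility re-enabled, this
        // @JsonProperty-annotated constructor is not seen by the mapper, and
        // deserializing {"min":0.5} fails.
        public MyMeta(@JsonProperty("min") double min) {
            this.min = min;
        }
    }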
- * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring - * a set of JsonTypeInfo annotations - */ - @AllArgsConstructor - private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector { - - private List classList; - - @Override - protected TypeResolverBuilder _findTypeResolver(MapperConfig config, Annotated ann, JavaType baseType) { - if(ann instanceof AnnotatedClass){ - AnnotatedClass c = (AnnotatedClass)ann; - Class annClass = c.getAnnotated(); - - boolean isAssignable = false; - for(Class c2 : classList){ - if(c2.isAssignableFrom(annClass)){ - isAssignable = true; - break; - } - } - - if( isAssignable ){ - AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations(); - if(annotations == null || annotations.annotations() == null){ - //Probably not necessary - but here for safety - return super._findTypeResolver(config, ann, baseType); - } - - AnnotationMap newMap = null; - for(Annotation a : annotations.annotations()){ - Class annType = a.annotationType(); - if(annType == JsonTypeInfo.class){ - //Ignore the JsonTypeInfo annotation on the Layer class - continue; - } - if(newMap == null){ - newMap = new AnnotationMap(); - } - newMap.add(a); - } - if(newMap == null) - return null; - - //Pass the remaining annotations (if any) to the original introspector - AnnotatedClass ann2 = c.withAnnotations(newMap); - return super._findTypeResolver(config, ann2, baseType); - } - } - return super._findTypeResolver(config, ann, baseType); - } - } } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java deleted file mode 100644 index 5a9b48a7c..000000000 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.transform.serde.legacy; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.Getter; -import org.datavec.api.transform.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Map; - -@AllArgsConstructor -@Data -public class GenericLegacyDeserializer extends BaseLegacyDeserializer { - - @Getter - protected final Class deserializedType; - @Getter - protected final Map legacyNamesMap; - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getLegacyMapperFor(getDeserializedType()); - } -} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java new file mode 100644 index 000000000..8df741a49 --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java @@ -0,0 +1,267 @@ +package org.datavec.api.transform.serde.legacy; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.datavec.api.transform.Transform; +import org.datavec.api.transform.analysis.columns.*; +import org.datavec.api.transform.condition.BooleanCondition; +import org.datavec.api.transform.condition.Condition; +import org.datavec.api.transform.condition.column.*; +import org.datavec.api.transform.condition.sequence.SequenceLengthCondition; +import org.datavec.api.transform.condition.string.StringRegexColumnCondition; +import org.datavec.api.transform.filter.ConditionFilter; +import org.datavec.api.transform.filter.Filter; +import org.datavec.api.transform.filter.FilterInvalidValues; +import org.datavec.api.transform.filter.InvalidNumColumns; +import org.datavec.api.transform.metadata.*; +import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform; +import org.datavec.api.transform.ndarray.NDArrayDistanceTransform; +import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform; +import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform; +import org.datavec.api.transform.rank.CalculateSortedRank; +import org.datavec.api.transform.schema.Schema; +import org.datavec.api.transform.schema.SequenceSchema; +import org.datavec.api.transform.sequence.ReduceSequenceTransform; +import org.datavec.api.transform.sequence.SequenceComparator; +import org.datavec.api.transform.sequence.SequenceSplit; +import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; +import org.datavec.api.transform.sequence.comparator.StringComparator; +import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation; +import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence; +import org.datavec.api.transform.sequence.trim.SequenceTrimTransform; +import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction; +import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; +import org.datavec.api.transform.sequence.window.TimeWindowFunction; +import org.datavec.api.transform.sequence.window.WindowFunction; +import org.datavec.api.transform.stringreduce.IStringReducer; +import org.datavec.api.transform.stringreduce.StringReducer; +import org.datavec.api.transform.transform.categorical.*; +import org.datavec.api.transform.transform.column.*; +import 
org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform; +import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform; +import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault; +import org.datavec.api.transform.transform.doubletransform.*; +import org.datavec.api.transform.transform.integer.*; +import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform; +import org.datavec.api.transform.transform.longtransform.LongMathOpTransform; +import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; +import org.datavec.api.transform.transform.parse.ParseDoubleTransform; +import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform; +import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform; +import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform; +import org.datavec.api.transform.transform.string.*; +import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform; +import org.datavec.api.transform.transform.time.StringToTimeTransform; +import org.datavec.api.transform.transform.time.TimeMathOpTransform; +import org.datavec.api.writable.*; +import org.datavec.api.writable.comparator.*; +import org.nd4j.shade.jackson.annotation.JsonInclude; +import org.nd4j.shade.jackson.annotation.JsonSubTypes; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +/** + * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override + * the existing annotations. + * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field". 
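To make the format switch described above concrete, here is the same hypothetical filter in both shapes (payload fields elided; shapes indicative only):

    // 1.0.0-alpha and earlier -- subtype as a wrapper object:
    //   {"ConditionFilter" : {"condition" : { ... }}}
    // 1.0.0-beta and later -- subtype as an "@class" property:
    //   {"@class" : "org.datavec.api.transform.filter.ConditionFilter", "condition" : { ... }}
    ObjectMapper om = LegacyJsonFormat.legacyMapper();
    // legacyFilterJson: a String in the wrapper-object form; IOException handling elided
    Filter f = om.readValue(legacyFilterJson, Filter.class);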
+ * We use these mixins to allow us to still load the old format + * + * @author Alex Black + */ +public class LegacyJsonFormat { + + private LegacyJsonFormat(){ } + + /** + * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before + * @return Object mapper + */ + public static ObjectMapper legacyMapper(){ + ObjectMapper om = new ObjectMapper(); + om.addMixIn(Schema.class, SchemaMixin.class); + om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class); + om.addMixIn(Transform.class, TransformMixin.class); + om.addMixIn(Condition.class, ConditionMixin.class); + om.addMixIn(Writable.class, WritableMixin.class); + om.addMixIn(Filter.class, FilterMixin.class); + om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class); + om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class); + om.addMixIn(WindowFunction.class, WindowFunctionMixin.class); + om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class); + om.addMixIn(WritableComparator.class, WritableComparatorMixin.class); + om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class); + om.addMixIn(IStringReducer.class, IStringReducerMixin.class); + return om; + } + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"), + @JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SchemaMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"), + @JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"), + @JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"), + @JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"), + @JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"), + @JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"), + @JsonSubTypes.Type(value = LongMetaData.class, name = "Long"), + @JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"), + @JsonSubTypes.Type(value = StringMetaData.class, name = "String"), + @JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ColumnMetaDataMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"), + @JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"), + @JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"), + @JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"), + @JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"), + @JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"), + @JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"), + @JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"), + @JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"), + @JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"), + 
@JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"), + @JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"), + @JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"), + @JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"), + @JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"), + @JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"), + @JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"), + @JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"), + @JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"), + @JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"), + @JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"), + @JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"), + @JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"), + @JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"), + @JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"), + @JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"), + @JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"), + @JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"), + @JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"), + @JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"), + @JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"), + @JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"), + @JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"), + @JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"), + @JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"), + @JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"), + @JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"), + @JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"), + @JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"), + @JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"), + @JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"), + @JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"), + @JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"), + @JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"), + @JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"), + @JsonSubTypes.Type(value 
= SequenceOffsetTransform.class, name = "SequenceOffsetTransform"), + @JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"), + @JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"), + @JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"), + @JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"), + @JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"), + @JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"), + @JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"), + @JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"), + @JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"), + @JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class TransformMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"), + @JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"), + @JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"), + @JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"), + @JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"), + @JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"), + @JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"), + @JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"), + @JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"), + @JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"), + @JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"), + @JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"), + @JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ConditionMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"), + @JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"), + @JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"), + @JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"), + @JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"), + @JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"), + @JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"), + @JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"), + @JsonSubTypes.Type(value = Text.class, name = "Text"), + @JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WritableMixin { } + + 
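A concrete read through the mixin-equipped mapper, using the "DoubleWritable" wrapper name registered just above (the inner "value" field name is an assumption based on DoubleWritable's @JsonProperty constructor touched elsewhere in this patch):

    ObjectMapper om = JsonMappers.getLegacyMapper();
    String legacyJson = "{\"DoubleWritable\":{\"value\":2.5}}";
    Writable w = om.readValue(legacyJson, Writable.class); // a DoubleWritable holding 2.5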
@JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"), + @JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"), + @JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class FilterMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"), + @JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SequenceComparatorMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"), + @JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SequenceSplitMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"), + @JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WindowFunctionMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class CalculateSortedRankMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"), + @JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"), + @JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"), + @JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"), + @JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WritableComparatorMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"), + @JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"), + @JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"), + @JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"), + @JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"), + @JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"), + @JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ColumnAnalysisMixin{ } + 
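One design point worth noting: these mixins are registered only on the legacy ObjectMapper instance, so the primary mapper is unaffected:

    ObjectMapper legacy = JsonMappers.getLegacyMapper(); // mixin-configured, created lazily once
    ObjectMapper current = JsonMappers.getMapper();      // still reads and writes the "@class" format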
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class IStringReducerMixin{ } +} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java deleted file mode 100644 index c4b478278..000000000 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java +++ /dev/null @@ -1,535 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.transform.serde.legacy; - -import org.datavec.api.transform.Transform; -import org.datavec.api.transform.analysis.columns.*; -import org.datavec.api.transform.condition.BooleanCondition; -import org.datavec.api.transform.condition.Condition; -import org.datavec.api.transform.condition.column.*; -import org.datavec.api.transform.condition.sequence.SequenceLengthCondition; -import org.datavec.api.transform.condition.string.StringRegexColumnCondition; -import org.datavec.api.transform.filter.ConditionFilter; -import org.datavec.api.transform.filter.Filter; -import org.datavec.api.transform.filter.FilterInvalidValues; -import org.datavec.api.transform.filter.InvalidNumColumns; -import org.datavec.api.transform.metadata.*; -import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform; -import org.datavec.api.transform.ndarray.NDArrayDistanceTransform; -import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform; -import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform; -import org.datavec.api.transform.rank.CalculateSortedRank; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.transform.sequence.ReduceSequenceTransform; -import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; -import org.datavec.api.transform.sequence.comparator.StringComparator; -import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation; -import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence; -import org.datavec.api.transform.sequence.trim.SequenceTrimTransform; -import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction; -import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; -import org.datavec.api.transform.sequence.window.TimeWindowFunction; -import org.datavec.api.transform.sequence.window.WindowFunction; -import 
org.datavec.api.transform.stringreduce.IStringReducer; -import org.datavec.api.transform.stringreduce.StringReducer; -import org.datavec.api.transform.transform.categorical.*; -import org.datavec.api.transform.transform.column.*; -import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform; -import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform; -import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault; -import org.datavec.api.transform.transform.doubletransform.*; -import org.datavec.api.transform.transform.integer.*; -import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform; -import org.datavec.api.transform.transform.longtransform.LongMathOpTransform; -import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; -import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform; -import org.datavec.api.transform.transform.parse.ParseDoubleTransform; -import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform; -import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform; -import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform; -import org.datavec.api.transform.transform.string.*; -import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform; -import org.datavec.api.transform.transform.time.StringToTimeTransform; -import org.datavec.api.transform.transform.time.TimeMathOpTransform; -import org.datavec.api.writable.*; -import org.datavec.api.writable.comparator.*; -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -import java.util.HashMap; -import java.util.Map; - -public class LegacyMappingHelper { - - public static Map legacyMappingForClass(Class c){ - //Need to be able to get the map - and they need to be mutable... 
- switch (c.getSimpleName()){ - case "Transform": - return getLegacyMappingImageTransform(); - case "ColumnAnalysis": - return getLegacyMappingColumnAnalysis(); - case "Condition": - return getLegacyMappingCondition(); - case "Filter": - return getLegacyMappingFilter(); - case "ColumnMetaData": - return mapColumnMetaData; - case "CalculateSortedRank": - return mapCalculateSortedRank; - case "Schema": - return mapSchema; - case "SequenceComparator": - return mapSequenceComparator; - case "SequenceSplit": - return mapSequenceSplit; - case "WindowFunction": - return mapWindowFunction; - case "IStringReducer": - return mapIStringReducer; - case "Writable": - return mapWritable; - case "WritableComparator": - return mapWritableComparator; - case "ImageTransform": - return mapImageTransform; - default: - //Should never happen - throw new IllegalArgumentException("No legacy mapping available for class " + c.getName()); - } - } - - private static Map mapTransform; - private static Map mapColumnAnalysis; - private static Map mapCondition; - private static Map mapFilter; - private static Map mapColumnMetaData; - private static Map mapCalculateSortedRank; - private static Map mapSchema; - private static Map mapSequenceComparator; - private static Map mapSequenceSplit; - private static Map mapWindowFunction; - private static Map mapIStringReducer; - private static Map mapWritable; - private static Map mapWritableComparator; - private static Map mapImageTransform; - - private static synchronized Map getLegacyMappingTransform(){ - - if(mapTransform == null) { - //The following classes all used their class short name - Map m = new HashMap<>(); - m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName()); - m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName()); - m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName()); - m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName()); - m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName()); - m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName()); - m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName()); - m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName()); - m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName()); - m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName()); - m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName()); - m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName()); - m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName()); - m.put("Log2Normalizer", Log2Normalizer.class.getName()); - m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName()); - m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName()); - m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName()); - m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName()); - m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName()); - m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName()); - m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName()); - m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName()); - m.put("LongMathOpTransform", 
LongMathOpTransform.class.getName()); - m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName()); - m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName()); - m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName()); - m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName()); - m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName()); - m.put("StringMapTransform", StringMapTransform.class.getName()); - m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName()); - m.put("StringToTimeTransform", StringToTimeTransform.class.getName()); - m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName()); - m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName()); - m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName()); - m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName()); - m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName()); - m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName()); - m.put("ConvertToStringTransform", ConvertToString.class.getName()); - m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName()); - m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName()); - m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName()); - m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName()); - m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName()); - m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName()); - m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName()); - m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName()); - m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName()); - m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName()); - m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName()); - m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName()); - m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName()); - m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName()); - m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName()); - m.put("PivotTransform", PivotTransform.class.getName()); - m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName()); - m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName()); - m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName()); - - mapTransform = m; - } - - return mapTransform; - } - - private static Map getLegacyMappingColumnAnalysis(){ - if(mapColumnAnalysis == null) { - Map m = new HashMap<>(); - m.put("BytesAnalysis", BytesAnalysis.class.getName()); - m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName()); - m.put("DoubleAnalysis", DoubleAnalysis.class.getName()); - m.put("IntegerAnalysis", IntegerAnalysis.class.getName()); - m.put("LongAnalysis", LongAnalysis.class.getName()); - 
m.put("StringAnalysis", StringAnalysis.class.getName()); - m.put("TimeAnalysis", TimeAnalysis.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName()); - - mapColumnAnalysis = m; - } - - return mapColumnAnalysis; - } - - private static Map getLegacyMappingCondition(){ - if(mapCondition == null) { - Map m = new HashMap<>(); - m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName()); - m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName()); - m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName()); - m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName()); - m.put("LongColumnCondition", LongColumnCondition.class.getName()); - m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName()); - m.put("StringColumnCondition", StringColumnCondition.class.getName()); - m.put("TimeColumnCondition", TimeColumnCondition.class.getName()); - m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName()); - m.put("BooleanCondition", BooleanCondition.class.getName()); - m.put("NaNColumnCondition", NaNColumnCondition.class.getName()); - m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName()); - m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName()); - m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName()); - - mapCondition = m; - } - - return mapCondition; - } - - private static Map getLegacyMappingFilter(){ - if(mapFilter == null) { - Map m = new HashMap<>(); - m.put("ConditionFilter", ConditionFilter.class.getName()); - m.put("FilterInvalidValues", FilterInvalidValues.class.getName()); - m.put("InvalidNumCols", InvalidNumColumns.class.getName()); - - mapFilter = m; - } - return mapFilter; - } - - private static Map getLegacyMappingColumnMetaData(){ - if(mapColumnMetaData == null) { - Map m = new HashMap<>(); - m.put("Categorical", CategoricalMetaData.class.getName()); - m.put("Double", DoubleMetaData.class.getName()); - m.put("Float", FloatMetaData.class.getName()); - m.put("Integer", IntegerMetaData.class.getName()); - m.put("Long", LongMetaData.class.getName()); - m.put("String", StringMetaData.class.getName()); - m.put("Time", TimeMetaData.class.getName()); - m.put("NDArray", NDArrayMetaData.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName()); - m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName()); - - mapColumnMetaData = m; - } - - return mapColumnMetaData; - } - - private static Map getLegacyMappingCalculateSortedRank(){ - if(mapCalculateSortedRank == null) { - Map m = new HashMap<>(); - m.put("CalculateSortedRank", CalculateSortedRank.class.getName()); - mapCalculateSortedRank = m; - } - return mapCalculateSortedRank; - } - - private static Map getLegacyMappingSchema(){ - if(mapSchema == null) { - Map m = new HashMap<>(); - m.put("Schema", Schema.class.getName()); - m.put("SequenceSchema", SequenceSchema.class.getName()); - - mapSchema = m; - } - return mapSchema; - } - - private static Map getLegacyMappingSequenceComparator(){ - 
if(mapSequenceComparator == null) { - Map m = new HashMap<>(); - m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName()); - m.put("StringComparator", StringComparator.class.getName()); - - mapSequenceComparator = m; - } - return mapSequenceComparator; - } - - private static Map getLegacyMappingSequenceSplit(){ - if(mapSequenceSplit == null) { - Map m = new HashMap<>(); - m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName()); - m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName()); - - mapSequenceSplit = m; - } - return mapSequenceSplit; - } - - private static Map getLegacyMappingWindowFunction(){ - if(mapWindowFunction == null) { - Map m = new HashMap<>(); - m.put("TimeWindowFunction", TimeWindowFunction.class.getName()); - m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName()); - - mapWindowFunction = m; - } - return mapWindowFunction; - } - - private static Map getLegacyMappingIStringReducer(){ - if(mapIStringReducer == null) { - Map m = new HashMap<>(); - m.put("StringReducer", StringReducer.class.getName()); - - mapIStringReducer = m; - } - return mapIStringReducer; - } - - private static Map getLegacyMappingWritable(){ - if (mapWritable == null) { - Map m = new HashMap<>(); - m.put("ArrayWritable", ArrayWritable.class.getName()); - m.put("BooleanWritable", BooleanWritable.class.getName()); - m.put("ByteWritable", ByteWritable.class.getName()); - m.put("DoubleWritable", DoubleWritable.class.getName()); - m.put("FloatWritable", FloatWritable.class.getName()); - m.put("IntWritable", IntWritable.class.getName()); - m.put("LongWritable", LongWritable.class.getName()); - m.put("NullWritable", NullWritable.class.getName()); - m.put("Text", Text.class.getName()); - m.put("BytesWritable", BytesWritable.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName()); - - mapWritable = m; - } - - return mapWritable; - } - - private static Map getLegacyMappingWritableComparator(){ - if(mapWritableComparator == null) { - Map m = new HashMap<>(); - m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName()); - m.put("FloatWritableComparator", FloatWritableComparator.class.getName()); - m.put("IntWritableComparator", IntWritableComparator.class.getName()); - m.put("LongWritableComparator", LongWritableComparator.class.getName()); - m.put("TextWritableComparator", TextWritableComparator.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName()); - m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName()); - m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName()); - m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName()); - m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName()); - m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName()); - m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName()); - m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName()); - - mapWritableComparator = m; - } - - return mapWritableComparator; - } - - public static Map 
getLegacyMappingImageTransform(){ - if(mapImageTransform == null) { - Map m = new HashMap<>(); - m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform"); - m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform"); - m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform"); - m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform"); - m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform"); - m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform"); - m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform"); - m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform"); - m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform"); - m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform"); - m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform"); - m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform"); - - mapImageTransform = m; - } - return mapImageTransform; - } - - @JsonDeserialize(using = LegacyTransformDeserializer.class) - public static class TransformHelper { } - - public static class LegacyTransformDeserializer extends GenericLegacyDeserializer { - public LegacyTransformDeserializer() { - super(Transform.class, getLegacyMappingTransform()); - } - } - - @JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class) - public static class ColumnAnalysisHelper { } - - public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer { - public LegacyColumnAnalysisDeserializer() { - super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis()); - } - } - - @JsonDeserialize(using = LegacyConditionDeserializer.class) - public static class ConditionHelper { } - - public static class LegacyConditionDeserializer extends GenericLegacyDeserializer { - public LegacyConditionDeserializer() { - super(Condition.class, getLegacyMappingCondition()); - } - } - - @JsonDeserialize(using = LegacyFilterDeserializer.class) - public static class FilterHelper { } - - public static class LegacyFilterDeserializer extends GenericLegacyDeserializer { - public LegacyFilterDeserializer() { - super(Filter.class, getLegacyMappingFilter()); - } - } - - @JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class) - public static class ColumnMetaDataHelper { } - - public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer { - public LegacyColumnMetaDataDeserializer() { - super(ColumnMetaData.class, getLegacyMappingColumnMetaData()); - } - } - - @JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class) - public static class CalculateSortedRankHelper { } - - public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer { - public LegacyCalculateSortedRankDeserializer() { - super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank()); - } - } - - @JsonDeserialize(using = LegacySchemaDeserializer.class) - public static class SchemaHelper { } - - public static class LegacySchemaDeserializer extends GenericLegacyDeserializer { - public LegacySchemaDeserializer() { - super(Schema.class, getLegacyMappingSchema()); - } - } - - @JsonDeserialize(using = LegacySequenceComparatorDeserializer.class) - public static class SequenceComparatorHelper { } - - public static class LegacySequenceComparatorDeserializer extends 
GenericLegacyDeserializer { - public LegacySequenceComparatorDeserializer() { - super(SequenceComparator.class, getLegacyMappingSequenceComparator()); - } - } - - @JsonDeserialize(using = LegacySequenceSplitDeserializer.class) - public static class SequenceSplitHelper { } - - public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer { - public LegacySequenceSplitDeserializer() { - super(SequenceSplit.class, getLegacyMappingSequenceSplit()); - } - } - - @JsonDeserialize(using = LegacyWindowFunctionDeserializer.class) - public static class WindowFunctionHelper { } - - public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer { - public LegacyWindowFunctionDeserializer() { - super(WindowFunction.class, getLegacyMappingWindowFunction()); - } - } - - - @JsonDeserialize(using = LegacyIStringReducerDeserializer.class) - public static class IStringReducerHelper { } - - public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer { - public LegacyIStringReducerDeserializer() { - super(IStringReducer.class, getLegacyMappingIStringReducer()); - } - } - - - @JsonDeserialize(using = LegacyWritableDeserializer.class) - public static class WritableHelper { } - - public static class LegacyWritableDeserializer extends GenericLegacyDeserializer { - public LegacyWritableDeserializer() { - super(Writable.class, getLegacyMappingWritable()); - } - } - - @JsonDeserialize(using = LegacyWritableComparatorDeserializer.class) - public static class WritableComparatorHelper { } - - public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer { - public LegacyWritableComparatorDeserializer() { - super(WritableComparator.class, getLegacyMappingWritableComparator()); - } - } -} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java index f43e189d8..54bd7b7c8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.stringreduce; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -31,8 +30,7 @@ import java.util.List; * a single List */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.IStringReducerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface IStringReducer extends Serializable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index e3ba797c8..c55d4d3bb 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -16,7 +16,7 @@ package org.datavec.api.util.ndarray; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import 
it.unimi.dsi.fastutil.doubles.DoubleArrayList; import lombok.NonNull; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java index 2584d8b9f..ae5e3a567 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java index 72f4b3c5b..39f41c076 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java index 1b54e7d54..f0ab62bef 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java index 3803c8098..1c127e0f8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java index 58cd45829..4a767a183 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java 
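The com.google.common to org.nd4j.shade.guava swaps in these hunks are purely mechanical: the same Guava API, relocated so DataVec no longer collides with whatever Guava version is on the user's classpath. Usage against the shaded copy looks like this (a sketch; only the relocated package names are taken from this patch):

    import org.nd4j.shade.guava.base.Preconditions;
    import org.nd4j.shade.guava.collect.Lists;
    import org.nd4j.shade.guava.math.DoubleMath;

    import java.util.List;

    public class ShadedGuavaUsage {
        public static void main(String[] args) {
            List<Double> values = Lists.newArrayList(1.0, 2.0);
            Preconditions.checkArgument(!values.isEmpty(), "values must not be empty");
            //Same Guava semantics, relocated package: tolerance-based comparison,
            //as used by the Writable implementations in this patch
            System.out.println(DoubleMath.fuzzyEquals(values.get(0), 1.0, 1e-9));
        }
    }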
b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java index 30eb5e25e..5085dd3f2 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java @@ -16,7 +16,6 @@ package org.datavec.api.writable; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.DataInput; @@ -60,8 +59,7 @@ import java.io.Serializable; * } *
*/ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.WritableHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Writable extends Serializable { /** * Serialize the fields of this object to out. diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java index 4d638124e..0a5ddddb8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java @@ -16,7 +16,7 @@ package org.datavec.api.writable.batch; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Data; import lombok.NonNull; import org.datavec.api.writable.NDArrayWritable; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java index 07ef7ee56..b8e540f61 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java @@ -16,16 +16,13 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.Serializable; import java.util.Comparator; -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface WritableComparator extends Comparator, Serializable { } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index 647c0af65..e8ce37bd3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -16,7 +16,7 @@ package org.datavec.api.split; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.labels.ParentPathLabelGenerator; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index e67fc1f61..c9fb57eb9 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -16,7 +16,7 @@ package org.datavec.api.split.parittion; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; diff --git 
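With defaultImpl gone, Writable and WritableComparator rely solely on the @class property for polymorphic typing. A hedged round-trip sketch (the exact JSON field layout of DoubleWritable is illustrative):

    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.nd4j.shade.jackson.databind.ObjectMapper;

    public class WritableJsonRoundTrip {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();
            //Serialized form carries the fully qualified class name, e.g.
            //{"@class" : "org.datavec.api.writable.DoubleWritable", ...}
            String json = mapper.writeValueAsString(new DoubleWritable(1.5));
            Writable w = mapper.readValue(json, Writable.class);
            System.out.println(w);
        }
    }

Legacy-format JSON with no @class property is now handled by the dedicated legacy deserialization path rather than by a defaultImpl fallback on every annotated interface.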
a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index c98b58381..dff90f8b9 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -78,8 +78,9 @@ public class TestJsonYaml { public void testMissingPrimitives() { Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build(); - - String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n" + //Legacy format JSON + String strJson = "{\n" + " \"Schema\" : {\n" + + " \"columns\" : [ {\n" + " \"Double\" : {\n" + " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" + //" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test //" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index ce7d779dc..6dfacdd93 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -16,7 +16,7 @@ package org.datavec.api.writable; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.junit.Test; diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml index 4d4381790..645971a45 100644 --- a/datavec/datavec-arrow/pom.xml +++ b/datavec/datavec-arrow/pom.xml @@ -34,36 +34,6 @@ nd4j-arrow ${project.version}
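The TestJsonYaml change above keeps a legacy-format JSON string around precisely because fromJson must still parse the old wrapper style ("Schema"/"Double" keys) while tolerating missing primitives such as allowNaN. The current-format round trip, grounded in the builder call the test uses:

    import org.datavec.api.transform.schema.Schema;

    public class SchemaJsonRoundTrip {
        public static void main(String[] args) {
            Schema schema = new Schema.Builder()
                    .addColumnDouble("Dbl2", null, 100.0, false, false)  //no min, max 100.0, no NaN/Infinity
                    .build();
            String json = schema.toJson();            //serialize to the current JSON format
            Schema fromJson = Schema.fromJson(json);  //fromJson also accepts the legacy wrapper format
            System.out.println(schema.equals(fromJson));
        }
    }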
- - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-xml - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${spark2.jackson.version} - org.datavec datavec-api @@ -74,11 +44,6 @@ hppc ${hppc.version} - - com.google.guava - guava - ${guava.version} - org.apache.arrow arrow-vector diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index d80f12a2e..a962dd84a 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -16,7 +16,7 @@ package org.datavec.image.recordreader; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java deleted file mode 100644 index 5e7b09c12..000000000 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java +++ /dev/null @@ -1,35 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.image.serde; - -import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; -import org.datavec.image.transform.ImageTransform; -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -public class LegacyImageMappingHelper { - - @JsonDeserialize(using = LegacyImageTransformDeserializer.class) - public static class ImageTransformHelper { } - - public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer { - public LegacyImageTransformDeserializer() { - super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform()); - } - } - -} diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java index 39239494b..afcdf894f 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java @@ -16,11 +16,8 @@ package org.datavec.image.transform; -import lombok.Data; import org.datavec.image.data.ImageWritable; -import org.datavec.image.serde.LegacyImageMappingHelper; import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.util.Random; @@ -31,8 +28,7 @@ import java.util.Random; * @author saudet */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ImageTransform { /** diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-hadoop/pom.xml index 38889228c..c95e6d3bc 100644 --- a/datavec/datavec-hadoop/pom.xml +++ b/datavec/datavec-hadoop/pom.xml @@ -50,11 +50,6 @@ netty ${netty.version} - - com.google.guava - guava - ${guava.version} - org.apache.commons commons-compress diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java index 01bef8fa9..58f7a57db 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java index cf1a801f5..1cbe47176 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java +++ 
b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java index 992c91312..faf41cbb4 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java index d35becd7a..7cd112c63 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.writer; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.records.converter.RecordReaderConverter; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; diff --git a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java index 5360eae7e..276d79f88 100644 --- a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java +++ b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java @@ -16,7 +16,7 @@ package org.datavec.local.transforms.join; -import com.google.common.collect.Iterables; +import org.nd4j.shade.guava.collect.Iterables; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 940ad01cc..605b13b70 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -64,12 +64,6 @@ ${datavec.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - - joda-time joda-time @@ -106,40 +100,10 @@ ${snakeyaml.version} - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - - 
com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${jackson.version} - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${jackson.version} - - com.typesafe.play play-java_2.11 - ${play.version} + ${playframework.version} com.google.code.findbugs @@ -161,25 +125,31 @@ com.typesafe.play play-json_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play-server_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play-netty-server_2.11 - ${play.version} + ${playframework.version} + + + + com.typesafe.akka + akka-cluster_2.11 + 2.5.23 @@ -194,6 +164,12 @@ jcommander ${jcommander.version} + + + org.apache.spark + spark-core_2.11 + ${spark.version} +
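The dependency changes above (the play.version to playframework.version property rename, Akka pinned to a Play-2.7-compatible 2.5.23, and an explicit spark-core dependency) support the Play upgrade in the server classes that follow. The new routing shape, reduced to a standalone sketch built only from APIs that appear in this patch (the port and route are placeholders):

    import play.BuiltInComponents;
    import play.Mode;
    import play.routing.Router;
    import play.routing.RoutingDsl;
    import play.server.Server;

    import static play.mvc.Results.ok;

    public class RoutingSketch {
        public static void main(String[] args) {
            //Router construction is deferred until Play hands us its components
            Server server = Server.forRouter(Mode.PROD, 9000, RoutingSketch::createRouter);
        }

        static Router createRouter(BuiltInComponents components) {
            RoutingDsl routingDsl = RoutingDsl.fromComponents(components);
            //Handlers receive the Http.Request explicitly instead of calling
            //the deprecated static Controller.request()
            routingDsl.GET("/ping").routingTo(req -> ok("pong"));
            return routingDsl.build();
        }
    }

The servers also generate and set a random play.crypto.secret at startup when none is configured, since Play refuses to run in PROD mode with a missing or default application secret.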
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java index 893ce4218..f20799905 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java @@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.transform.TransformProcess; import org.datavec.image.transform.ImageTransformProcess; import org.datavec.spark.transform.model.*; +import play.BuiltInComponents; import play.Mode; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; import java.io.File; import java.io.IOException; +import java.util.Base64; +import java.util.Random; import static play.mvc.Results.*; @@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - - if (jsonPath != null) { String json = FileUtils.readFileToString(new File(jsonPath)); TransformProcess transformProcess = TransformProcess.fromJson(json); @@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer { + "to /transformprocess"); } + //Set play secret key, if required + //http://www.playframework.com/documentation/latest/ApplicationSecret + String crypto = System.getProperty("play.crypto.secret"); + if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) { + byte[] newCrypto = new byte[1024]; - routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { + new Random().nextBytes(newCrypto); + + String base64 = Base64.getEncoder().encodeToString(newCrypto); + System.setProperty("play.crypto.secret", base64); + } + + + server = Server.forRouter(Mode.PROD, port, this::createRouter); + } + + protected Router createRouter(BuiltInComponents b){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(b); + + routingDsl.GET("/transformprocess").routingTo(req -> { try { if (transform == null) return badRequest(); @@ -88,11 +107,11 @@ public class CSVSparkTransformServer extends SparkTransformServer { log.error("Error in GET /transformprocess",e); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformprocess").routingTo(req -> { try { - TransformProcess transformProcess = TransformProcess.fromJson(getJsonText()); + TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); setCSVTransformProcess(transformProcess); log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); @@ -100,12 +119,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { log.error("Error in POST /transformprocess",e); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformincremental").routingTo(req -> { + if (isSequence(req)) { try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord record = 
objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); @@ -115,7 +134,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); + SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); @@ -124,12 +143,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } - }))); + }); - routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transform").routingTo(req -> { + if (isSequence(req)) { try { - SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class)); + SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); if (batch == null) return badRequest(); return ok(objectMapper.writeValueAsString(batch)).as(contentType); @@ -139,7 +158,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); BatchCSVRecord batch = transform(input); if (batch == null) return badRequest(); @@ -149,14 +168,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } + }); - - }))); - - routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformincrementalarray").routingTo(req -> { + if (isSequence(req)) { try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); @@ -166,7 +183,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); + SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); @@ -175,13 +192,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } + }); - }))); - - routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformarray").routingTo(req -> { + if (isSequence(req)) { try { - SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class); + SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); if (batchCSVRecord == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); @@ -191,7 +207,7 @@ public class CSVSparkTransformServer extends 
SparkTransformServer { } } else { try { - BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); if (batchCSVRecord == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); @@ -200,10 +216,9 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } - }))); + }); - - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); + return routingDsl.build(); } public static void main(String[] args) throws Exception { diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java deleted file mode 100644 index 6c4874b02..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import play.libs.F; -import play.mvc.Result; - -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; - } - - public static F.Function function(Function function) { - return function::apply; - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java index 29c1e1bd7..f8675f139 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java @@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.transform.TransformProcess; import org.datavec.image.transform.ImageTransformProcess; import org.datavec.spark.transform.model.*; +import play.BuiltInComponents; import play.Mode; +import play.libs.Files; import play.mvc.Http; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -33,6 +36,7 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import static play.mvc.Controller.request; import static play.mvc.Results.*; @@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - if (jsonPath != null) { String json = FileUtils.readFileToString(new File(jsonPath)); ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); @@ -73,7 +75,13 @@ public class ImageSparkTransformServer extends SparkTransformServer { + "to /transformprocess"); } - routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { + server = Server.forRouter(Mode.PROD, port, this::createRouter); + } + + protected Router createRouter(BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); + + routingDsl.GET("/transformprocess").routingTo(req -> { try { if (transform == null) return badRequest(); @@ -83,11 +91,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformprocess").routingTo(req -> { try { - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText()); + ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); setImageTransformProcess(transformProcess); log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); @@ -95,11 +103,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - 
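Two mechanical changes recur through the handlers here: the Http.Request becomes an explicit lambda parameter rather than the static request(), and multipart uploads are typed as FilePart<Files.TemporaryFile>, with the underlying file reached via getRef().path(). One tidy way to write the request-header helper against the new API (the patch itself uses hasHeader() plus header().get(); the header constant's value below is assumed for illustration, the real one lives in DataVecTransformService):

    import play.mvc.Http;

    public class RequestHelperSketch {
        static final String SEQUENCE_OR_NOT_HEADER = "Sequence";  //hypothetical value

        static boolean isSequence(Http.Request request) {
            //Optional-based header access replaces hasHeader()/getHeader()
            return request.header(SEQUENCE_OR_NOT_HEADER)
                    .map(v -> v.equalsIgnoreCase("true"))
                    .orElse(false);
        }
    }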
routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformincrementalarray").routingTo(req -> { try { - SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class); + SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); @@ -107,17 +115,17 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformincrementalimage").routingTo(req -> { try { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List files = body.getFiles(); - if (files.size() == 0 || files.get(0).getFile() == null) { + Http.MultipartFormData body = req.body().asMultipartFormData(); + List> files = body.getFiles(); + if (files.isEmpty() || files.get(0).getRef() == null ) { return badRequest(); } - File file = files.get(0).getFile(); + File file = files.get(0).getRef().path().toFile(); SingleImageRecord record = new SingleImageRecord(file.toURI()); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); @@ -125,11 +133,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformarray").routingTo(req -> { try { - BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class); + BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); if (batch == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); @@ -137,22 +145,22 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformimage").routingTo(req -> { try { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List files = body.getFiles(); + Http.MultipartFormData body = req.body().asMultipartFormData(); + List> files = body.getFiles(); if (files.size() == 0) { return badRequest(); } List records = new ArrayList<>(); - for (Http.MultipartFormData.FilePart filePart : files) { - File file = filePart.getFile(); + for (Http.MultipartFormData.FilePart filePart : files) { + Files.TemporaryFile file = filePart.getRef(); if (file != null) { - SingleImageRecord record = new SingleImageRecord(file.toURI()); + SingleImageRecord record = new SingleImageRecord(file.path().toUri()); records.add(record); } } @@ -164,9 +172,9 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); + return routingDsl.build(); } @Override diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java index 2d4c92836..411872006 
100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java @@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.service.DataVecTransformService; import org.nd4j.shade.jackson.databind.ObjectMapper; +import play.mvc.Http; import play.server.Server; import static play.mvc.Controller.request; @@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService { server.stop(); } - protected boolean isSequence() { - return request().hasHeader(SEQUENCE_OR_NOT_HEADER) - && request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase() - .equals("TRUE"); + protected boolean isSequence(Http.Request request) { + return request.hasHeader(SEQUENCE_OR_NOT_HEADER) + && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); } - - protected String getHeaderValue(String value) { - if (request().hasHeader(value)) - return request().getHeader(value); - return null; - } - - protected String getJsonText() { - JsonNode tryJson = request().body().asJson(); + protected String getJsonText(Http.Request request) { + JsonNode tryJson = request.body().asJson(); if (tryJson != null) return tryJson.toString(); else - return request().body().asText(); + return request.body().asText(); } public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf new file mode 100644 index 000000000..28a4aa208 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf @@ -0,0 +1,350 @@ +# This is the main configuration file for the application. +# https://www.playframework.com/documentation/latest/ConfigFile +# ~~~~~ +# Play uses HOCON as its configuration file format. HOCON has a number +# of advantages over other config formats, but there are two things that +# can be used when modifying settings. +# +# You can include other configuration files in this main application.conf file: +#include "extra-config.conf" +# +# You can declare variables and substitute for them: +#mykey = ${some.value} +# +# And if an environment variable exists when there is no other substitution, then +# HOCON will fall back to substituting the environment variable: +#mykey = ${JAVA_HOME} + +## Akka +# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration +# https://www.playframework.com/documentation/latest/JavaAkka#Configuration +# ~~~~~ +# Play uses Akka internally and exposes Akka Streams and actors in Websockets and +# other streaming HTTP responses. +akka { + # "akka.log-config-on-start" is extraordinarily useful because it logs the complete + # configuration at INFO level, including defaults and overrides, so it's worth + # putting at the very top. + # + # Put the following in your conf/logback.xml file: + # + # + # + # And then uncomment this line to debug the configuration. 
+ # + #log-config-on-start = true +} + +## Modules +# https://www.playframework.com/documentation/latest/Modules +# ~~~~~ +# Control which modules are loaded when Play starts. Note that modules are +# the replacement for "GlobalSettings", which are deprecated in 2.5.x. +# Please see https://www.playframework.com/documentation/latest/GlobalSettings +# for more information. +# +# You can also extend Play functionality by using one of the publicly available +# Play modules: https://playframework.com/documentation/latest/ModuleDirectory +play.modules { + # By default, Play will load any class called Module that is defined + # in the root package (the "app" directory), or you can define them + # explicitly below. + # If there are any built-in modules that you want to enable, you can list them here. + #enabled += my.application.Module + + # If there are any built-in modules that you want to disable, you can list them here. + #disabled += "" +} + +## Internationalisation +# https://www.playframework.com/documentation/latest/JavaI18N +# https://www.playframework.com/documentation/latest/ScalaI18N +# ~~~~~ +# Play comes with its own i18n settings, which allow the user's preferred language +# to map through to internal messages, or allow the language to be stored in a cookie. +play.i18n { + # The application languages + langs = [ "en" ] + + # Whether the language cookie should be secure or not + #langCookieSecure = true + + # Whether the HTTP only attribute of the cookie should be set to true + #langCookieHttpOnly = true +} + +## Play HTTP settings +# ~~~~~ +play.http { + ## Router + # https://www.playframework.com/documentation/latest/JavaRouting + # https://www.playframework.com/documentation/latest/ScalaRouting + # ~~~~~ + # Define the Router object to use for this application. + # This router will be looked up first when the application is starting up, + # so make sure this is the entry point. + # Furthermore, it's assumed your route file is named properly. + # So for an application router like `my.application.Router`, + # you may need to define a router file `conf/my.application.routes`. + # Default to Routes in the root package (aka "apps" folder) (and conf/routes) + #router = my.application.Router + + ## Action Creator + # https://www.playframework.com/documentation/latest/JavaActionCreator + # ~~~~~ + #actionCreator = null + + ## ErrorHandler + # https://www.playframework.com/documentation/latest/JavaRouting + # https://www.playframework.com/documentation/latest/ScalaRouting + # ~~~~~ + # If null, will attempt to load a class called ErrorHandler in the root package, + #errorHandler = null + + ## Filters + # https://www.playframework.com/documentation/latest/ScalaHttpFilters + # https://www.playframework.com/documentation/latest/JavaHttpFilters + # ~~~~~ + # Filters run code on every request. They can be used to perform + # common logic for all your actions, e.g. adding common headers. + # Defaults to "Filters" in the root package (aka "apps" folder) + # Alternatively you can explicitly register a class here. + #filters += my.application.Filters + + ## Session & Flash + # https://www.playframework.com/documentation/latest/JavaSessionFlash + # https://www.playframework.com/documentation/latest/ScalaSessionFlash + # ~~~~~ + session { + # Sets the cookie to be sent only over HTTPS. + #secure = true + + # Sets the cookie to be accessed only by the server. + #httpOnly = true + + # Sets the max-age field of the cookie to 5 minutes. 
+ # NOTE: this only controls when the browser will discard the cookie. Play will consider any + # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, + # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. + #maxAge = 300 + + # Sets the domain on the session cookie. + #domain = "example.com" + } + + flash { + # Sets the cookie to be sent only over HTTPS. + #secure = true + + # Sets the cookie to be accessed only by the server. + #httpOnly = true + } +} + +## Netty Provider +# https://www.playframework.com/documentation/latest/SettingsNetty +# ~~~~~ +play.server.netty { + # Whether the Netty wire should be logged + #log.wire = true + + # If you run Play on Linux, you can use Netty's native socket transport + # for higher performance with less garbage. + #transport = "native" +} + +## WS (HTTP Client) +# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS +# ~~~~~ +# The HTTP client primarily used for REST APIs. The default client can be +# configured directly, but you can also create different client instances +# with customized settings. You must enable this by adding to build.sbt: +# +# libraryDependencies += ws // or javaWs if using java +# +play.ws { + # Sets HTTP requests not to follow 302 redirects +  #followRedirects = false + + # Sets the maximum number of open HTTP connections for the client. + #ahc.maxConnectionsTotal = 50 + + ## WS SSL + # https://www.playframework.com/documentation/latest/WsSSL + # ~~~~~ + ssl { + # Configuring HTTPS with Play WS does not require programming. You can + # set up both trustManager and keyManager for mutual authentication, and + # turn on JSSE debugging in development with a reload. + #debug.handshake = true + #trustManager = { + # stores = [ + # { type = "JKS", path = "exampletrust.jks" } + # ] + #} + } +} + +## Cache +# https://www.playframework.com/documentation/latest/JavaCache +# https://www.playframework.com/documentation/latest/ScalaCache +# ~~~~~ +# Play comes with an integrated cache API that can reduce the operational +# overhead of repeated requests. You must enable this by adding to build.sbt: +# +# libraryDependencies += cache +# +play.cache { + # If you want to bind several caches, you can bind them individually + #bindCaches = ["db-cache", "user-cache", "session-cache"] +} + +## Filters +# https://www.playframework.com/documentation/latest/Filters +# ~~~~~ +# There are a number of built-in filters that can be enabled and configured +# to give Play greater security. You must enable this by adding to build.sbt: +# +# libraryDependencies += filters +# +play.filters { + ## CORS filter configuration + # https://www.playframework.com/documentation/latest/CorsFilter + # ~~~~~ + # CORS is a protocol that allows web applications to make requests from the browser + # across different domains. + # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has + # dependencies on CORS settings. + cors { + # Filter paths by a whitelist of path prefixes + #pathPrefixes = ["/some/path", ...] + + # The allowed origins. If null, all origins are allowed. + #allowedOrigins = ["http://www.example.com"] + + # The allowed HTTP methods. If null, all methods are allowed + #allowedHttpMethods = ["GET", "POST"] + } + + ## CSRF Filter + # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter + # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter + # ~~~~~ + # Play supports multiple methods for verifying that a request is not a CSRF request. + # The primary mechanism is a CSRF token. This token gets placed either in the query string + # or body of every form submitted, and also gets placed in the user's session. + # Play then verifies that both tokens are present and match. + csrf { + # Sets the cookie to be sent only over HTTPS + #cookie.secure = true + + # Defaults to CSRFErrorHandler in the root package. + #errorHandler = MyCSRFErrorHandler + } + + ## Security headers filter configuration + # https://www.playframework.com/documentation/latest/SecurityHeaders + # ~~~~~ + # Defines security headers that prevent XSS attacks. + # If enabled, then all options are set to the below configuration by default: + headers { + # The X-Frame-Options header. If null, the header is not set. + #frameOptions = "DENY" + + # The X-XSS-Protection header. If null, the header is not set. + #xssProtection = "1; mode=block" + + # The X-Content-Type-Options header. If null, the header is not set. + #contentTypeOptions = "nosniff" + + # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. + #permittedCrossDomainPolicies = "master-only" + + # The Content-Security-Policy header. If null, the header is not set. + #contentSecurityPolicy = "default-src 'self'" + } + + ## Allowed hosts filter configuration + # https://www.playframework.com/documentation/latest/AllowedHostsFilter + # ~~~~~ + # Play provides a filter that lets you configure which hosts can access your application. + # This is useful to prevent cache poisoning attacks. + hosts { + # Allow requests to example.com, its subdomains, and localhost:9000. + #allowed = [".example.com", "localhost:9000"] + } +} + +## Evolutions +# https://www.playframework.com/documentation/latest/Evolutions +# ~~~~~ +# Evolutions allows database scripts to be automatically run on startup in dev mode +# for database migrations. You must enable this by adding to build.sbt: +# +# libraryDependencies += evolutions +# +play.evolutions { + # You can disable evolutions for a specific datasource if necessary + #db.default.enabled = false +} + +## Database Connection Pool +# https://www.playframework.com/documentation/latest/SettingsJDBC +# ~~~~~ +# Play doesn't require a JDBC database to run, but you can easily enable one. +# +# libraryDependencies += jdbc +# +play.db { + # The combination of these two settings results in "db.default" as the + # default JDBC pool: + #config = "db" + #default = "default" + + # Play uses HikariCP as the default connection pool. 
You can override + # settings by changing the prototype: + prototype { + # Sets a fixed JDBC connection pool size of 50 + #hikaricp.minimumIdle = 50 + #hikaricp.maximumPoolSize = 50 + } +} + +## JDBC Datasource +# https://www.playframework.com/documentation/latest/JavaDatabase +# https://www.playframework.com/documentation/latest/ScalaDatabase +# ~~~~~ +# Once JDBC datasource is set up, you can work with several different +# database options: +# +# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick +# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA +# EBean: https://playframework.com/documentation/latest/JavaEbean +# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm +# +db { + # You can declare as many datasources as you want. + # By convention, the default datasource is named `default` + + # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database + default.driver = org.h2.Driver + default.url = "jdbc:h2:mem:play" + #default.username = sa + #default.password = "" + + # You can expose this datasource via JNDI if needed (Useful for JPA) + default.jndiName=DefaultDS + + # You can turn on SQL logging for any datasource + # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements + #default.logSql=true +} + +jpa.default=defaultPersistenceUnit + + +#Increase default maximum post length - used for remote listener functionality +#Can get response 413 with larger networks without setting this +# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead +#parsers.text.maxLength=10M +play.http.parser.maxMemoryBuffer=10M diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index d98730407..05c505cac 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -28,61 +28,11 @@ datavec-spark_2.11 - - 2.1.0 - 2 - 2.11.12 2.11 - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-source - generate-sources - - add-source - - - - src/main/spark-${spark.major.version} - - - - - - - - - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_2.11 - ${jackson.version} - - - - org.scala-lang @@ -95,42 +45,13 @@ ${scala.version} - - org.codehaus.jackson - jackson-core-asl - ${jackson-asl.version} - - - org.codehaus.jackson - jackson-mapper-asl - ${jackson-asl.version} - org.apache.spark spark-sql_2.11 ${spark.version} + provided - - com.google.guava - guava - ${guava.version} - - - com.google.inject - guice - ${guice.version} - - - com.google.protobuf - protobuf-java - ${google.protobuf.version} - - - commons-codec - commons-codec - ${commons-codec.version} - commons-collections commons-collections @@ -141,96 +62,16 @@ commons-io ${commons-io.version} - - commons-lang - commons-lang - ${commons-lang.version} - - - commons-net - commons-net - ${commons-net.version} - - - com.sun.xml.bind - jaxb-core - ${jaxb.version} - - - com.sun.xml.bind - jaxb-impl - ${jaxb.version} - - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - - - io.netty - netty - ${netty.version} - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - 
${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - javax.servlet - javax.servlet-api - ${servlet.version} - - - org.apache.commons - commons-compress - ${commons-compress.version} - - - org.apache.commons - commons-lang3 - ${commons-lang3.version} - org.apache.commons commons-math3 ${commons-math3.version} - - org.apache.curator - curator-recipes - ${curator.version} - org.slf4j slf4j-api ${slf4j.version} - - com.typesafe - config - ${typesafe.config.version} - org.apache.spark spark-core_2.11 @@ -241,14 +82,6 @@ com.google.code.findbugs jsr305 - - org.slf4j - slf4j-log4j12 - - - log4j - log4j - diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java deleted file mode 100644 index 8aeae58a5..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java +++ /dev/null @@ -1,29 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.functions; - -import java.io.Serializable; - -/** - * - * A function that returns zero or more output records from each input record. 
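With the Spark 1 adapter layer being deleted here (and spark-sql now a provided dependency in the pom above), flat-map code can implement Spark 2's own FlatMapFunction directly; note that its call() returns an Iterator, not the Iterable the adapter exposed. A minimal replacement sketch:

    import org.apache.spark.api.java.function.FlatMapFunction;

    import java.util.Arrays;
    import java.util.Iterator;

    public class SplitWords implements FlatMapFunction<String, String> {
        @Override
        public Iterator<String> call(String line) {
            //One input record to zero or more output records
            return Arrays.asList(line.split("\\s+")).iterator();
        }
    }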
- * - * Adapter for Spark interface in order to freeze interface changes between spark versions - */ -public interface FlatMapFunctionAdapter extends Serializable { - Iterable call(T t) throws Exception; -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java index fbe8a63d4..5d0bff7f3 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java @@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.functions; +import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -46,7 +43,6 @@ import java.util.List; import static org.apache.spark.sql.functions.avg; import static org.apache.spark.sql.functions.col; -import static org.datavec.spark.transform.DataRowsFacade.dataRows; /** @@ -71,7 +67,7 @@ public class DataFrames { * deviation for * @return the column that represents the standard deviation */ - public static Column std(DataRowsFacade dataFrame, String columnName) { + public static Column std(Dataset dataFrame, String columnName) { return functions.sqrt(var(dataFrame, columnName)); } @@ -85,8 +81,8 @@ public class DataFrames { * deviation for * @return the column that represents the standard deviation */ - public static Column var(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName); + public static Column var(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName); } /** @@ -97,8 +93,8 @@ public class DataFrames { * @param columnName the name of the column to get the min for * @return the column that represents the min */ - public static Column min(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName); + public static Column min(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName); } /** @@ -110,8 +106,8 @@ public class DataFrames { * to get the max for * @return the column that represents the max */ - public static Column max(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName); + public static Column max(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName); } /** @@ -122,8 +118,8 @@ public class DataFrames { * @param columnName the name of the column to get the mean for * @return the column that represents the mean */ - public static Column mean(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName); + public static Column mean(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName); } /** @@ -166,7 +162,7 @@ public class DataFrames { * - 
Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN}) - an index (integer, starting at 0) for the position * of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example - * using {@link #toRecordsSequence(DataRowsFacade)} + * using {@link #toRecordsSequence(Dataset)} * * @param schema Schema to convert * @return StructType for the schema @@ -250,9 +246,9 @@ public class DataFrames { * @param dataFrame the dataframe to convert * @return the converted schema and rdd of writables */ - public static Pair>> toRecords(DataRowsFacade dataFrame) { - Schema schema = fromStructType(dataFrame.get().schema()); - return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema))); + public static Pair>> toRecords(Dataset dataFrame) { + Schema schema = fromStructType(dataFrame.schema()); + return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema))); } /** @@ -267,11 +263,11 @@ public class DataFrames { * @param dataFrame Data frame to convert * @return Data in sequence (i.e., {@code List>} form */ - public static Pair>>> toRecordsSequence(DataRowsFacade dataFrame) { + public static Pair>>> toRecordsSequence(Dataset dataFrame) { //Need to convert from flattened to sequence data... //First: Group by the Sequence UUID (first column) - JavaPairRDD> grouped = dataFrame.get().javaRDD().groupBy(new Function() { + JavaPairRDD> grouped = dataFrame.javaRDD().groupBy(new Function() { @Override public String call(Row row) throws Exception { return row.getString(0); @@ -279,7 +275,7 @@ public class DataFrames { }); - Schema schema = fromStructType(dataFrame.get().schema()); + Schema schema = fromStructType(dataFrame.schema()); //Group by sequence UUID, and sort each row within the sequences using the time step index Function, List>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner @@ -318,11 +314,11 @@ public class DataFrames { * @param data the data to convert * @return the dataframe object */ - public static DataRowsFacade toDataFrame(Schema schema, JavaRDD> data) { + public static Dataset toDataFrame(Schema schema, JavaRDD> data) { JavaSparkContext sc = new JavaSparkContext(data.context()); SQLContext sqlContext = new SQLContext(sc); JavaRDD rows = data.map(new ToRow(schema)); - return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema))); + return sqlContext.createDataFrame(rows, fromSchema(schema)); } @@ -333,18 +329,18 @@ public class DataFrames { * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position * of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example - * using {@link #toRecordsSequence(DataRowsFacade)} + * using {@link #toRecordsSequence(Dataset)} * * @param schema Schema for the data * @param data Sequence data to convert to a DataFrame * @return The dataframe object */ - public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD>> data) { + public static Dataset toDataFrameSequence(Schema schema, JavaRDD>> data) { JavaSparkContext sc = new JavaSparkContext(data.context()); SQLContext sqlContext = new SQLContext(sc); JavaRDD rows = data.flatMap(new SequenceToRows(schema)); - return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema))); + return sqlContext.createDataFrame(rows, fromSchemaSequence(schema)); } /** diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java index 68efd1888..cacea101d 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java @@ -19,14 +19,13 @@ package org.datavec.spark.transform; import org.apache.commons.collections.map.ListOrderedMap; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Writable; import java.util.*; -import static org.datavec.spark.transform.DataRowsFacade.dataRows; - /** * Simple dataframe based normalization. @@ -46,7 +45,7 @@ public class Normalization { * @return a zero mean unit variance centered * rdd */ - public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) { + public static Dataset zeromeanUnitVariance(Dataset frame) { return zeromeanUnitVariance(frame, Collections.emptyList()); } @@ -71,7 +70,7 @@ public class Normalization { * @param max the maximum value * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) { + public static Dataset normalize(Dataset dataFrame, double min, double max) { return normalize(dataFrame, min, max, Collections.emptyList()); } @@ -86,7 +85,7 @@ public class Normalization { */ public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(normalize(frame, min, max, Collections.emptyList())).getSecond(); } @@ -97,7 +96,7 @@ public class Normalization { * @param dataFrame the dataframe to scale * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame) { + public static Dataset normalize(Dataset dataFrame) { return normalize(dataFrame, 0, 1, Collections.emptyList()); } @@ -120,8 +119,8 @@ public class Normalization { * @return a zero mean unit variance centered * rdd */ - public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List skipColumns) { - List columnsList = DataFrames.toList(frame.get().columns()); + public static Dataset zeromeanUnitVariance(Dataset frame, List skipColumns) { + List columnsList = DataFrames.toList(frame.columns()); columnsList.removeAll(skipColumns); String[] columnNames = DataFrames.toArray(columnsList); //first row is std 
second row is mean, each column in a row is for a particular column @@ -133,7 +132,7 @@ public class Normalization { if (std == 0.0) std = 1; //All same value -> (x-x)/1 = 0 - frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std))); + frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std)); } @@ -152,7 +151,7 @@ public class Normalization { */ public static JavaRDD> zeromeanUnitVariance(Schema schema, JavaRDD> data, List skipColumns) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond(); } @@ -178,7 +177,7 @@ public class Normalization { */ public static JavaRDD>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD>> sequence, List excludeColumns) { - DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence); + Dataset frame = DataFrames.toDataFrameSequence(schema, sequence); if (excludeColumns == null) excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN); else { @@ -196,7 +195,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List minMaxColumns(DataRowsFacade data, List columns) { + public static List minMaxColumns(Dataset data, List columns) { String[] arr = new String[columns.size()]; for (int i = 0; i < arr.length; i++) arr[i] = columns.get(i); @@ -210,7 +209,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List minMaxColumns(DataRowsFacade data, String... columns) { + public static List minMaxColumns(Dataset data, String... columns) { return aggregate(data, columns, new String[] {"min", "max"}); } @@ -221,7 +220,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List stdDevMeanColumns(DataRowsFacade data, List columns) { + public static List stdDevMeanColumns(Dataset data, List columns) { String[] arr = new String[columns.size()]; for (int i = 0; i < arr.length; i++) arr[i] = columns.get(i); @@ -237,7 +236,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List stdDevMeanColumns(DataRowsFacade data, String... columns) { + public static List stdDevMeanColumns(Dataset data, String... columns) { return aggregate(data, columns, new String[] {"stddev", "mean"}); } @@ -251,7 +250,7 @@ public class Normalization { * Each row will be a function with the desired columnar output * in the order in which the columns were specified. 
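For reference, a minimal usage sketch of the Dataset<Row>-based helpers above; the schema, the input records and the JavaSparkContext sc are illustrative assumptions, and only the DataFrames/Normalization calls themselves come from this patch:

    // Build a small all-double schema and convert the records to a Spark DataFrame.
    Schema schema = new Schema.Builder()
            .addColumnDouble("x")
            .addColumnDouble("y")
            .build();
    JavaRDD<List<Writable>> records = sc.parallelize(data);   // data: List<List<Writable>>, assumed
    Dataset<Row> df = DataFrames.toDataFrame(schema, records);
    // Row 0 holds the per-column standard deviations, row 1 the per-column means,
    // in the column order passed in - the "stddev"/"mean" ordering described above.
    List<Row> stats = Normalization.stdDevMeanColumns(df, df.columns());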
*/ - public static List aggregate(DataRowsFacade data, String[] columns, String[] functions) { + public static List aggregate(Dataset data, String[] columns, String[] functions) { String[] rest = new String[columns.length - 1]; System.arraycopy(columns, 1, rest, 0, rest.length); List rows = new ArrayList<>(); @@ -262,8 +261,8 @@ public class Normalization { } //compute the aggregation based on the operation - DataRowsFacade aggregated = dataRows(data.get().agg(expressions)); - String[] columns2 = aggregated.get().columns(); + Dataset aggregated = data.agg(expressions); + String[] columns2 = aggregated.columns(); //strip out the op name and parentheses from the columns Map opReplace = new TreeMap<>(); for (String s : columns2) { @@ -278,20 +277,20 @@ public class Normalization { //get rid of the operation name in the column - DataRowsFacade rearranged = null; + Dataset rearranged = null; for (Map.Entry entries : opReplace.entrySet()) { //first column if (rearranged == null) { - rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue())); + rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue()); } //rearranged is just a copy of aggregated at this point else - rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue())); + rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue()); } - rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns))); + rearranged = rearranged.select(DataFrames.toColumns(columns)); //op - rows.addAll(rearranged.get().collectAsList()); + rows.addAll(rearranged.collectAsList()); } @@ -307,8 +306,8 @@ public class Normalization { * @param max the maximum value * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List skipColumns) { - List columnsList = DataFrames.toList(dataFrame.get().columns()); + public static Dataset normalize(Dataset dataFrame, double min, double max, List skipColumns) { + List columnsList = DataFrames.toList(dataFrame.columns()); columnsList.removeAll(skipColumns); String[] columnNames = DataFrames.toArray(columnsList); //first row is min second row is max, each column in a row is for a particular column @@ -321,8 +320,8 @@ public class Normalization { if (maxSubMin == 0) maxSubMin = 1; - Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min); - dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol)); + Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min); + dataFrame = dataFrame.withColumn(columnName, newCol); } @@ -340,7 +339,7 @@ public class Normalization { */ public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max, List skipColumns) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond(); } @@ -387,7 +386,7 @@ public class Normalization { excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN); excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN); } - DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data); + Dataset frame = DataFrames.toDataFrameSequence(schema, data); return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond(); } @@ -398,7 +397,7 @@ public class Normalization { * @param dataFrame 
the dataframe to scale * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, List skipColumns) { + public static Dataset normalize(Dataset dataFrame, List skipColumns) { return normalize(dataFrame, 0, 1, skipColumns); } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java index 6b7ff203f..5052491bb 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java @@ -16,9 +16,10 @@ package org.datavec.spark.transform.analysis; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Iterator; import java.util.List; /** @@ -27,10 +28,11 @@ import java.util.List; * * @author Alex Black */ -public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee>, List> { +public class SequenceFlatMapFunction implements FlatMapFunction>, List> { - public SequenceFlatMapFunction() { - super(new SequenceFlatMapFunctionAdapter()); + @Override + public Iterator> call(List> collections) throws Exception { + return collections.iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java deleted file mode 100644 index 6b25fb826..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.analysis; - -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.List; - -/** - * SequenceFlatMapFunction: very simple function used to flatten a sequence - * Typically used only internally for certain analysis operations - * - * @author Alex Black - */ -public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter>, List> { - @Override - public Iterable> call(List> collections) throws Exception { - return collections; - } - -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java index 52bf924be..6e501f560 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java @@ -16,11 +16,14 @@ package org.datavec.spark.transform.join; +import org.nd4j.shade.guava.collect.Iterables; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Iterator; import java.util.List; /** @@ -28,10 +31,89 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteJoinFromCoGroupFlatMapFunction extends - BaseFlatMapFunctionAdaptee, Tuple2>, Iterable>>>, List> { +public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction, Tuple2>, Iterable>>>, List> { + + private final Join join; public ExecuteJoinFromCoGroupFlatMapFunction(Join join) { - super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join)); + this.join = join; + } + + @Override + public Iterator> call( + Tuple2, Tuple2>, Iterable>>> t2) + throws Exception { + + Iterable> leftList = t2._2()._1(); + Iterable> rightList = t2._2()._2(); + + List> ret = new ArrayList<>(); + Join.JoinType jt = join.getJoinType(); + switch (jt) { + case Inner: + //Return records where key columns appear in BOTH + //So if no values from left OR right: no return values + for (List jvl : leftList) { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + break; + case LeftOuter: + //Return all records from left, even if no corresponding right value (NullWritable in that case) + for (List jvl : leftList) { + if (Iterables.size(rightList) == 0) { + List joined = join.joinExamples(jvl, null); + ret.add(joined); + } else { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + case RightOuter: + //Return all records from right, even if no corresponding left value (NullWritable in that case) + for (List jvr : rightList) { + if (Iterables.size(leftList) == 0) { + List joined = join.joinExamples(null, jvr); + ret.add(joined); + } else { + for (List jvl : leftList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + case FullOuter: + //Return all records, even if no corresponding left/right value (NullWritable in that case) + if (Iterables.size(leftList) == 0) { + //Only right values + 
for (List jvr : rightList) { + List joined = join.joinExamples(null, jvr); + ret.add(joined); + } + } else if (Iterables.size(rightList) == 0) { + //Only left values + for (List jvl : leftList) { + List joined = join.joinExamples(jvl, null); + ret.add(joined); + } + } else { + //Records from both left and right + for (List jvl : leftList) { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + } + + return ret.iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java deleted file mode 100644 index dedff46d0..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.join; - -import com.google.common.collect.Iterables; -import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -import java.util.ArrayList; -import java.util.List; - -/** - * Execute a join - * - * @author Alex Black - */ -public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements - FlatMapFunctionAdapter, Tuple2>, Iterable>>>, List> { - - private final Join join; - - public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) { - this.join = join; - } - - @Override - public Iterable> call( - Tuple2, Tuple2>, Iterable>>> t2) - throws Exception { - - Iterable> leftList = t2._2()._1(); - Iterable> rightList = t2._2()._2(); - - List> ret = new ArrayList<>(); - Join.JoinType jt = join.getJoinType(); - switch (jt) { - case Inner: - //Return records where key columns appear in BOTH - //So if no values from left OR right: no return values - for (List jvl : leftList) { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - break; - case LeftOuter: - //Return all records from left, even if no corresponding right value (NullWritable in that case) - for (List jvl : leftList) { - if (Iterables.size(rightList) == 0) { - List joined = join.joinExamples(jvl, null); - ret.add(joined); - } else { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - case RightOuter: - //Return all records from right, even if no corresponding left value (NullWritable in that case) - for (List jvr : rightList) { - if (Iterables.size(leftList) == 0) { - List joined = join.joinExamples(null, jvr); - ret.add(joined); - } else { - for (List jvl : 
leftList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - case FullOuter: - //Return all records, even if no corresponding left/right value (NullWritable in that case) - if (Iterables.size(leftList) == 0) { - //Only right values - for (List jvr : rightList) { - List joined = join.joinExamples(null, jvr); - ret.add(joined); - } - } else if (Iterables.size(rightList) == 0) { - //Only left values - for (List jvl : leftList) { - List joined = join.joinExamples(jvl, null); - ret.add(joined); - } - } else { - //Records from both left and right - for (List jvl : leftList) { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - } - - return ret; - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java index 6e206a657..d4ede4808 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java @@ -16,10 +16,12 @@ package org.datavec.spark.transform.join; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Collections; +import java.util.Iterator; import java.util.List; /** @@ -29,10 +31,43 @@ import java.util.List; * * @author Alex Black */ -public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee> { +public class FilterAndFlattenJoinedValues implements FlatMapFunction> { + + private final Join.JoinType joinType; public FilterAndFlattenJoinedValues(Join.JoinType joinType) { - super(new FilterAndFlattenJoinedValuesAdapter(joinType)); + this.joinType = joinType; + } + + @Override + public Iterator> call(JoinedValue joinedValue) throws Exception { + boolean keep; + switch (joinType) { + case Inner: + //Only keep joined values where we have both left and right + keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight(); + break; + case LeftOuter: + //Keep all values where left is not missing/null + keep = joinedValue.isHaveLeft(); + break; + case RightOuter: + //Keep all values where right is not missing/null + keep = joinedValue.isHaveRight(); + break; + case FullOuter: + //Keep all values + keep = true; + break; + default: + throw new RuntimeException("Unknown/not implemented join type: " + joinType); + } + + if (keep) { + return Collections.singletonList(joinedValue.getValues()).iterator(); + } else { + return Collections.emptyIterator(); + } } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java deleted file mode 100644 index 3333276b1..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.join; - -import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Collections; -import java.util.List; - -/** - * Doing two things here: - * (a) filter out any unnecessary values, and - * (b) extract the List values from the JoinedValue - * - * @author Alex Black - */ -public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter> { - - private final Join.JoinType joinType; - - public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) { - this.joinType = joinType; - } - - @Override - public Iterable> call(JoinedValue joinedValue) throws Exception { - boolean keep; - switch (joinType) { - case Inner: - //Only keep joined values where we have both left and right - keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight(); - break; - case LeftOuter: - //Keep all values where left is not missing/null - keep = joinedValue.isHaveLeft(); - break; - case RightOuter: - //Keep all values where right is not missing/null - keep = joinedValue.isHaveRight(); - break; - case FullOuter: - //Keep all values - keep = true; - break; - default: - throw new RuntimeException("Unknown/not implemented join type: " + joinType); - } - - if (keep) { - return Collections.singletonList(joinedValue.getValues()); - } else { - return Collections.emptyList(); - } - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java index 28c91c84b..639e43836 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java @@ -16,21 +16,69 @@ package org.datavec.spark.transform.sparkfunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.types.StructType; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.datavec.spark.transform.DataFrames; -import java.util.List; +import java.util.*; /** * Convert a record to a row * @author Adam Gibson */ -public class SequenceToRows extends BaseFlatMapFunctionAdaptee>, Row> { +public class SequenceToRows implements FlatMapFunction>, Row> { + + private Schema schema; + private StructType structType; public SequenceToRows(Schema schema) { - super(new SequenceToRowsAdapter(schema)); + this.schema = schema; + structType = DataFrames.fromSchemaSequence(schema); } + + @Override + public Iterator call(List> sequence) throws 
Exception { + if (sequence.size() == 0) + return Collections.emptyIterator(); + + String sequenceUUID = UUID.randomUUID().toString(); + + List out = new ArrayList<>(sequence.size()); + + int stepCount = 0; + for (List step : sequence) { + Object[] values = new Object[step.size() + 2]; + values[0] = sequenceUUID; + values[1] = stepCount++; + for (int i = 0; i < step.size(); i++) { + switch (schema.getColumnTypes().get(i)) { + case Double: + values[i + 2] = step.get(i).toDouble(); + break; + case Integer: + values[i + 2] = step.get(i).toInt(); + break; + case Long: + values[i + 2] = step.get(i).toLong(); + break; + case Float: + values[i + 2] = step.get(i).toFloat(); + break; + default: + throw new IllegalStateException( + "This api should not be used with strings , binary data or ndarrays. This is only for columnar data"); + } + } + + Row row = new GenericRowWithSchema(values, structType); + out.add(row); + } + + return out.iterator(); + } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java deleted file mode 100644 index 2ca2f32ae..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.sparkfunction; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; -import org.apache.spark.sql.types.StructType; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.DataFrames; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -/** - * Convert a record to a row - * @author Adam Gibson - */ -public class SequenceToRowsAdapter implements FlatMapFunctionAdapter>, Row> { - - private Schema schema; - private StructType structType; - - public SequenceToRowsAdapter(Schema schema) { - this.schema = schema; - structType = DataFrames.fromSchemaSequence(schema); - } - - - @Override - public Iterable call(List> sequence) throws Exception { - if (sequence.size() == 0) - return Collections.emptyList(); - - String sequenceUUID = UUID.randomUUID().toString(); - - List out = new ArrayList<>(sequence.size()); - - int stepCount = 0; - for (List step : sequence) { - Object[] values = new Object[step.size() + 2]; - values[0] = sequenceUUID; - values[1] = stepCount++; - for (int i = 0; i < step.size(); i++) { - switch (schema.getColumnTypes().get(i)) { - case Double: - values[i + 2] = step.get(i).toDouble(); - break; - case Integer: - values[i + 2] = step.get(i).toInt(); - break; - case Long: - values[i + 2] = step.get(i).toLong(); - break; - case Float: - values[i + 2] = step.get(i).toFloat(); - break; - default: - throw new IllegalStateException( - "This api should not be used with strings , binary data or ndarrays. This is only for columnar data"); - } - } - - Row row = new GenericRowWithSchema(values, structType); - out.add(row); - } - - return out; - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java index 41981a736..1a8782dfb 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java @@ -16,19 +16,27 @@ package org.datavec.spark.transform.transform; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.sequence.SequenceSplit; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Iterator; import java.util.List; /** * Created by Alex on 17/03/2016. 
*/ -public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee>, List>> { +public class SequenceSplitFunction implements FlatMapFunction>, List>> { + + private final SequenceSplit split; public SequenceSplitFunction(SequenceSplit split) { - super(new SequenceSplitFunctionAdapter(split)); + this.split = split; + } + + @Override + public Iterator>> call(List> collections) throws Exception { + return split.split(collections).iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java deleted file mode 100644 index 5bde7ee62..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.transform; - -import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.List; - -/** - * Created by Alex on 17/03/2016. 
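The rewritten functions plug directly into Spark 2's flatMap, whose FlatMapFunction.call(...) returns an Iterator rather than Spark 1's Iterable, so the adapter indirection is no longer needed. A hedged pipeline sketch for the new SequenceSplitFunction; sc, seqData and the SplitMaxLengthSequence splitter (constructor arguments assumed) are illustrative, not part of this patch:

    // Assumed setup: sc is a JavaSparkContext, seqData is a List<List<List<Writable>>>.
    JavaRDD<List<List<Writable>>> sequences = sc.parallelize(seqData);
    // Split sequences longer than 100 steps; the function itself now returns an Iterator.
    JavaRDD<List<List<Writable>>> splitSequences =
            sequences.flatMap(new SequenceSplitFunction(new SplitMaxLengthSequence(100, false)));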
- */ -public class SequenceSplitFunctionAdapter - implements FlatMapFunctionAdapter>, List>> { - - private final SequenceSplit split; - - public SequenceSplitFunctionAdapter(SequenceSplit split) { - this.split = split; - } - - @Override - public Iterable>> call(List> collections) throws Exception { - return split.split(collections); - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java index ffe3f80c6..81f07b1f4 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java @@ -16,19 +16,32 @@ package org.datavec.spark.transform.transform; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.TransformProcess; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Collections; +import java.util.Iterator; import java.util.List; /** * Spark function for executing a transform process */ -public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee, List> { +public class SparkTransformProcessFunction implements FlatMapFunction, List> { + + private final TransformProcess transformProcess; public SparkTransformProcessFunction(TransformProcess transformProcess) { - super(new SparkTransformProcessFunctionAdapter(transformProcess)); + this.transformProcess = transformProcess; + } + + @Override + public Iterator> call(List v1) throws Exception { + List newList = transformProcess.execute(v1); + if (newList == null) + return Collections.emptyIterator(); //Example was filtered out + else + return Collections.singletonList(newList).iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java deleted file mode 100644 index 7b1766cc2..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.transform; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Collections; -import java.util.List; - -/** - * Spark function for executing a transform process - */ -public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter, List> { - - private final TransformProcess transformProcess; - - public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) { - this.transformProcess = transformProcess; - } - - @Override - public Iterable> call(List v1) throws Exception { - List newList = transformProcess.execute(v1); - if (newList == null) - return Collections.emptyList(); //Example was filtered out - else - return Collections.singletonList(newList); - } -} diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java deleted file mode 100644 index af600a14d..000000000 --- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.api.java.function.FlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -/** - * FlatMapFunction adapter to - * hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to FlatMapFunction - * - */ -public class BaseFlatMapFunctionAdaptee implements FlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterable call(K k) throws Exception { - return adapter.call(k); - } -} diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java deleted file mode 100644 index 0ad7c55bd..000000000 --- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.sql.DataFrame; - -/** - * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DataFrame / Dataset - * - */ -public class DataRowsFacade { - - private final DataFrame df; - - private DataRowsFacade(DataFrame df) { - this.df = df; - } - - public static DataRowsFacade dataRows(DataFrame df) { - return new DataRowsFacade(df); - } - - public DataFrame get() { - return df; - } -} diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java deleted file mode 100644 index f30e5a222..000000000 --- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.api.java.function.FlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Iterator; - -/** - * FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to FlatMapFunction - * - */ -public class BaseFlatMapFunctionAdaptee implements FlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterator call(K k) throws Exception { - return adapter.call(k).iterator(); - } -} diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java deleted file mode 100644 index 9958a622e..000000000 --- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; - -/** - * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DataFrame / Dataset - * - */ -public class DataRowsFacade { - - private final Dataset df; - - private DataRowsFacade(Dataset df) { - this.df = df; - } - - public static DataRowsFacade dataRows(Dataset df) { - return new DataRowsFacade(df); - } - - public Dataset get() { - return df; - } -} diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java index 49a3946e2..8f0247568 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java @@ -16,7 +16,7 @@ package org.datavec.spark.storage; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.datavec.api.writable.*; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java index 5b6ff6342..a19725a2a 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java @@ -19,6 +19,8 @@ package org.datavec.spark.transform; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; @@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest { for (int i = 0; i < numColumns; i++) builder.addColumnDouble(String.valueOf(i)); Schema schema = builder.build(); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records)); - dataFrame.get().show(); - dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records)); + dataFrame.show(); + dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show(); // System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames())); // System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames())); @@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest { assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); - DataRowsFacade dataFrame = 
DataFrames.toDataFrame(schema, rdd); - dataFrame.get().show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); + dataFrame.show(); Column mean = DataFrames.mean(dataFrame, "0"); Column std = DataFrames.std(dataFrame, "0"); - dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show(); - dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show(); + dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show(); + dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show(); /* DataFrame desc = dataFrame.describe(dataFrame.columns()); dataFrame.show(); diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 3b1b1e1a6..5352ec10d 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -17,6 +17,7 @@ package org.datavec.spark.transform; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; @@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; +import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; @@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest { for (int i = 0; i < numColumns; i++) builder.addColumnDouble(String.valueOf(i)); + Nd4j.getRandom().setSeed(12345); + + INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns); for (int i = 0; i < 5; i++) { List record = new ArrayList<>(numColumns); data.add(record); for (int j = 0; j < numColumns; j++) { - record.add(new DoubleWritable(1.0)); + record.add(new DoubleWritable(arr.getDouble(i, j))); } - } - INDArray arr = RecordConverter.toMatrix(data); Schema schema = builder.build(); JavaRDD> rdd = sc.parallelize(data); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); //assert equivalent to the ndarray pre processing - NormalizerStandardize standardScaler = new NormalizerStandardize(); - standardScaler.fit(new DataSet(arr.dup(), arr.dup())); - INDArray standardScalered = arr.dup(); - standardScaler.transform(new DataSet(standardScalered, standardScalered)); DataNormalization zeroToOne = new NormalizerMinMaxScaler(); zeroToOne.fit(new DataSet(arr.dup(), arr.dup())); INDArray zeroToOnes = arr.dup(); zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); - List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns()); + List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns()); INDArray assertion = DataFrames.toMatrix(rows); - //compare standard deviation - assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1)); + INDArray expStd = arr.std(true, true, 0); + INDArray std = assertion.getRow(0, true); + 
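    // Sketch/assumption note: Spark SQL's "stddev" aggregate is stddev_samp, the
    // bias-corrected sample standard deviation; that is why the expected values above
    // use the bias-corrected ND4J reductions, and why the tight 1e-3 epsilon is safe:
    //     INDArray expStd  = arr.std(true, true, 0);  // (biasCorrected, keepDims, dimension)
    //     INDArray expMean = arr.mean(true, 0);       // (keepDims, dimension)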
assertTrue(expStd.equalsWithEps(std, 1e-3)); //compare mean - assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1)); + INDArray expMean = arr.mean(true, 0); + assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3)); } @@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest { assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); - dataFrame.get().show(); - Normalization.zeromeanUnitVariance(dataFrame).get().show(); - Normalization.normalize(dataFrame).get().show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); + dataFrame.show(); + Normalization.zeromeanUnitVariance(dataFrame).show(); + Normalization.normalize(dataFrame).show(); //assert equivalent to the ndarray pre processing NormalizerStandardize standardScaler = new NormalizerStandardize(); diff --git a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java index 7b0d9d3c1..04773ee74 100644 --- a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java +++ b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java @@ -52,18 +52,6 @@ public class DL4JSystemProperties { */ public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl"; - /** - * Applicability: deeplearning4j-nn
- * Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an - * alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in - * comma-separated format.
- * This is required ONLY when ALL of the following conditions are met:
- * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND
- * 2. The serialized net has a custom Layer, GraphVertex, etc. (i.e., one not defined in DL4J), AND
- * 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])} - */ - public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses"; - /** * Applicability: deeplearning4j-nn
* Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index aeb5fe04b..81142fd68 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -96,18 +96,6 @@ test - - com.google.guava - guava - ${guava.version} - - - com.google.code.findbugs - jsr305 - - - - org.nd4j nd4j-api diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 56317b946..3bd1bd37f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.datasets.datavec; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 6ff639cb4..1e82a4783 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.datasets.datavec; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 935d83218..52d3b0774 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -16,8 +16,8 @@ package org.deeplearning4j.nn.dtypes; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; @@ -103,7 +103,7 @@ public class DTypeTests extends BaseDL4JTest { ImmutableSet info; try { //Dependency note: this ClassPath class was added in Guava 14 - info = com.google.common.reflect.ClassPath.from(DTypeTests.class.getClassLoader()) + info = org.nd4j.shade.guava.reflect.ClassPath.from(DTypeTests.class.getClassLoader()) .getTopLevelClassesRecursive("org.deeplearning4j"); } catch (IOException e) { //Should never happen diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index fb5836a99..11e45c51d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -229,7 +229,9 @@ public class TestRnnLayers extends BaseDL4JTest { net.fit(in,l); } catch (Throwable t){ String msg = t.getMessage(); - assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); + if(msg == null) + t.printStackTrace(); + assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java index 77fc63aa0..aba60fe4f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.plot; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.IOUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index f112d4386..a66914cd7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -22,23 +22,24 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; -import org.junit.Before; import org.junit.Ignore; import org.junit.Test; -import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Adam; import 
org.nd4j.linalg.learning.config.RmsProp;
 import org.nd4j.linalg.learning.regularization.WeightDecay;
@@ -60,6 +61,9 @@ public class RegressionTest100a extends BaseDL4JTest {
 
     @Test
     public void testCustomLayer() throws Exception {
+        //We dropped support for 1.0.0-alpha and earlier custom layers due to the maintenance overhead for a rarely used feature
+        //An upgrade path exists as a workaround - load in 1.0.0-beta to 1.0.0-beta4 and re-save
+        //All built-in layers can be loaded going back to 0.5.0
 
         File f = Resources.asFile("regression_testing/100a/CustomLayerExample_100a.bin");
 
@@ -68,67 +72,8 @@ public class RegressionTest100a extends BaseDL4JTest {
             fail("Expected exception");
         } catch (Exception e){
             String msg = e.getMessage();
-            assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
+            assertTrue(msg, msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again"));
         }
-
-        NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
-
-        MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
-
-        DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
-        assertEquals(new ActivationTanH(), l0.getActivationFn());
-        assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
-        assertEquals(new RmsProp(0.95), l0.getIUpdater());
-
-        CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
-        assertEquals(new ActivationTanH(), l1.getActivationFn());
-        assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
-        assertEquals(new RmsProp(0.95), l1.getIUpdater());
-
-
-        INDArray outExp;
-        File f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Output_100a.bin");
-        try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
-            outExp = Nd4j.read(dis);
-        }
-
-        INDArray in;
-        File f3 = Resources.asFile("regression_testing/100a/CustomLayerExample_Input_100a.bin");
-        try(DataInputStream dis = new DataInputStream(new FileInputStream(f3))){
-            in = Nd4j.read(dis);
-        }
-
-        INDArray outAct = net.output(in);
-
-        assertEquals(outExp, outAct);
-
-
-        //Check graph
-        f = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_100a.bin");
-
-        //Deregister custom class:
-        new LegacyLayerDeserializer().getLegacyNamesMap().remove("CustomLayer");
-
-        try {
-            ComputationGraph.load(f, true);
-            fail("Expected exception");
-        } catch (Exception e){
-            String msg = e.getMessage();
-            assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
-        }
-
-        NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
-
-        ComputationGraph graph = ComputationGraph.load(f, true);
-
-        f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_Output_100a.bin");
-        try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
-            outExp = Nd4j.read(dis);
-        }
-
-        outAct = graph.outputSingle(in);
-
-        assertEquals(outExp, outAct);
     }
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java
index e7fb7246a..592aefcd6 100755
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java
+++ 
b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java @@ -16,8 +16,8 @@ package org.deeplearning4j.datasets.iterator; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.annotations.VisibleForTesting; +import org.nd4j.shade.guava.collect.Lists; import lombok.Getter; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java index 955732c07..02f2c4eb0 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java @@ -16,7 +16,7 @@ package org.deeplearning4j.datasets.iterator.parallel; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index dd83c1bd4..7eca0fac0 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -17,7 +17,7 @@ package org.deeplearning4j.plot; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java index 41b50795e..9efb88e24 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java @@ -16,7 +16,7 @@ package org.deeplearning4j.plot; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dimensionalityreduction.PCA; diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 59f05a6e8..dec29266f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -37,6 +37,12 @@ ${nd4j.version} + + com.google.code.gson + gson + ${gson.version} + + org.deeplearning4j deeplearning4j-nn diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 5351f0955..7477c7794 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -77,26 +77,11 @@ ${project.version} - - com.google.guava - guava - ${guava.version} - com.google.protobuf protobuf-java ${google.protobuf.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - joda-time joda-time @@ -213,11 +198,6 @@ play-netty-server_2.11 ${playframework.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} -
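The explicit Akka and Guava entries dropped from this POM are covered elsewhere after this PR (Guava via the ND4J shaded copy, Akka via Play's own dependencies), and the NearestNeighborsServer changes below migrate the routing code to the Play 2.7 API. As a standalone reference, a minimal sketch of that new routing pattern follows; the /ping route, handler, and class name are illustrative, not part of this patch:

import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import static play.mvc.Results.ok;

public class RoutingSketch {
    public static void main(String[] args) {
        //Play 2.7: the server hands BuiltInComponents to a router factory,
        //replacing the old "new RoutingDsl()" + Server.forRouter(router, mode, port) pattern
        Server server = Server.forRouter(Mode.PROD, 9000, RoutingSketch::createRouter);
    }

    static Router createRouter(BuiltInComponents components) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(components);
        //Handlers now receive the Http.Request as a parameter, instead of the
        //removed thread-local request() accessor used by the old F.Function0-based routes
        routingDsl.GET("/ping").routingTo(request -> ok("pong"));
        return routingDsl.build();
    }
}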
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java deleted file mode 100644 index df178fd70..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nearestneighbor.server; - -import play.libs.F; -import play.mvc.Result; - -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; - } - - public static F.Function function(Function function) { - return function::apply; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index 58682d5e1..a79b57b19 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -33,8 +33,10 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.binary.BinarySerde; +import play.BuiltInComponents; import play.Mode; import play.libs.Json; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -149,19 +151,36 @@ public class NearestNeighborsServer { VPTree tree = new VPTree(points, similarityFunction, invert); - RoutingDsl routingDsl = new RoutingDsl(); + //Set play secret key, if required + //http://www.playframework.com/documentation/latest/ApplicationSecret + String crypto = System.getProperty("play.crypto.secret"); + if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { + byte[] newCrypto = new byte[1024]; + + new Random().nextBytes(newCrypto); + + String base64 = Base64.getEncoder().encodeToString(newCrypto); + System.setProperty("play.crypto.secret", 
base64); + } + + + server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); + } + + protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); //return the host information for a given id - routingDsl.POST("/knn").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/knn").routingTo(request -> { try { - NearestNeighborRequest record = Json.fromJson(request().body().asJson(), NearestNeighborRequest.class); + NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); NearestNeighbor nearestNeighbor = - NearestNeighbor.builder().points(points).record(record).tree(tree).build(); + NearestNeighbor.builder().points(points).record(record).tree(tree).build(); if (record == null) return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); NearestNeighborsResults results = - NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); + NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); return ok(Json.toJson(results)); @@ -171,11 +190,11 @@ public class NearestNeighborsServer { e.printStackTrace(); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/knnnew").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/knnnew").routingTo(request -> { try { - Base64NDArrayBody record = Json.fromJson(request().body().asJson(), Base64NDArrayBody.class); + Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); if (record == null) return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); @@ -216,23 +235,9 @@ public class NearestNeighborsServer { e.printStackTrace(); return internalServerError(e.getMessage()); } - }))); - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); - + }); + return routingDsl.build(); } /** diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml index 8b774cd56..f95f9268d 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml @@ -59,6 +59,13 @@ ${project.version} test + + + joda-time + joda-time + 2.10.3 + test + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java index 1c57bc38a..cae103f10 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java +++ 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java @@ -16,8 +16,8 @@ package org.deeplearning4j.clustering.info; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Table; +import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Table; import org.deeplearning4j.clustering.cluster.Cluster; import org.deeplearning4j.clustering.cluster.ClusterSet; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java index c2154e6ba..f1cc2e304 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.quadtree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java index 11746f4c2..5c31ee78a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.randomprojection; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java index 83af0365a..659f334df 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.sptree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.val; import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.nn.conf.WorkspaceMode; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 4b7b8e567..1de7a379b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.kdtree; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.deeplearning4j.clustering.BaseDL4JTest; import org.joda.time.Duration; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java index 03ad90748..f5ee19403 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.sptree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.apache.commons.lang3.time.StopWatch; import org.deeplearning4j.clustering.BaseDL4JTest; import org.junit.Before; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java index 76658438a..5edb3926a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java @@ -18,12 +18,10 @@ package org.deeplearning4j.clustering.vptree; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.clustering.BaseDL4JTest; import org.deeplearning4j.clustering.sptree.DataPoint; import org.joda.time.Duration; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,7 +31,6 @@ import org.nd4j.linalg.primitives.Counter; import org.nd4j.linalg.primitives.Pair; import java.util.*; -import java.util.concurrent.TimeUnit; import static org.junit.Assert.*; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java index 67169132c..03d2462a5 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java +++ 
b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java @@ -16,7 +16,7 @@ package org.deeplearning4j.text.corpora.sentiwordnet; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import org.apache.uima.analysis_engine.AnalysisEngine; import org.apache.uima.cas.CAS; import org.apache.uima.cas.CASException; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index d2d752509..f4dd1a6c5 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models; -import com.google.common.io.Files; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.io.Files; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.ArrayUtils; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 61e31b3c7..01b38a644 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models.word2vec; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Ints; import lombok.val; import net.didion.jwnl.data.Word; import org.apache.commons.io.FileUtils; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java index 5be56964c..579caa0a3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.inmemory; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index df502aded..84fc17b7e 100644 --- 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.reader.impl; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NonNull; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java index 75511cae1..f71c56717 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.wordvectors; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java index f97aee9c0..8680a809c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.glove.count; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.nd4j.linalg.primitives.Pair; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java index 4c05b7cc7..64ee79dd4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.paragraphvectors; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import com.google.gson.JsonObject; import com.google.gson.JsonParser; import lombok.Getter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index 78a878930..87dd0880a 100644 --- 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models.sequencevectors; -import com.google.common.primitives.Ints; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java index d5e789ebf..99263f6bc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.sequencevectors.sequence; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java index cdb4b6c9e..5e51cef20 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java @@ -16,7 +16,7 @@ package org.deeplearning4j.text.invertedindex; -import com.google.common.base.Function; +import org.nd4j.shade.guava.base.Function; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.nd4j.linalg.primitives.Pair; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java index 41536ff70..f5e0ca388 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.wordvectors; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java index 5834e1647..69767df13 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java @@ -16,8 +16,8 @@ package org.deeplearning4j.eval; -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; +import org.nd4j.shade.guava.collect.HashMultiset; +import org.nd4j.shade.guava.collect.Multiset; import lombok.Getter; import java.io.Serializable; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java index e964cd9b0..2b00ac375 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java @@ -16,7 +16,7 @@ package org.deeplearning4j.eval.curves; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java index b66230ddd..5d5e65c2a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java @@ -16,7 +16,7 @@ package org.deeplearning4j.eval.curves; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Data; import lombok.EqualsAndHashCode; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index 6cd8f06b3..5a5ce5665 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.OutputLayerUtil; @@ -40,6 +41,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -172,6 +174,26 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { ComputationGraphConfiguration conf; try { conf = mapper.readValue(json, ComputationGraphConfiguration.class); + } catch 
(InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class);
+                } catch (InvalidTypeIdException e2){
+                    //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
+                    //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
+                    String msg = e2.getMessage();
+                    if(msg != null && msg.contains("Could not resolve type id")){
+                        throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
+                                "layer, vertex or preprocessor, in pre-1.0.0-beta JSON format.\nModels in legacy format with custom" +
+                                " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
+                    }
+                    throw new RuntimeException(e2);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (Exception e) {
             //Check if this exception came from legacy deserializer...
             String msg = e.getMessage();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
index ab7fd044b..4f02e3d66 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
@@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf;
 
 import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializerHelper;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -34,8 +33,7 @@ import java.io.Serializable;
  *
  * @author Adam Gibson
  */
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-        defaultImpl = LegacyPreprocessorDeserializerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface InputPreProcessor extends Serializable, Cloneable {
 
     /**
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
index de3373323..52a67ef16 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
 import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
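From a caller's perspective this fallback is transparent: fromJson(...) first tries the current @class-based format and only consults the legacy mapper when the type id property is missing. The MultiLayerConfiguration hunks around this point add the identical fallback. A minimal usage sketch, where the JSON string is a placeholder for a real 1.0.0-alpha era configuration and the class name is illustrative:

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;

public class LegacyJsonLoadExample {
    public static void main(String[] args) {
        String legacyJson = "..."; //placeholder: configuration JSON saved by 1.0.0-alpha or earlier
        //No registration call or format flag needed - the fallback above picks the right mapper
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(legacyJson);
        //Built-in layers deserialize via the legacy mapper; pre-1.0.0-beta *custom* layers
        //instead fail with the RuntimeException above, which points at the re-save upgrade path
        System.out.println(conf.toJson()); //re-serializes in the current @class-based format
    }
}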
@@ -41,6 +42,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossMSE;
 import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
 import org.nd4j.shade.jackson.databind.JsonNode;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
 import org.nd4j.shade.jackson.databind.node.ArrayNode;
 
 import java.io.IOException;
@@ -157,6 +159,26 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
         ObjectMapper mapper = NeuralNetConfiguration.mapper();
         try {
             conf = mapper.readValue(json, MultiLayerConfiguration.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try {
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class);
+                } catch (InvalidTypeIdException e2){
+                    //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
+                    //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
+                    String msg = e2.getMessage();
+                    if(msg != null && msg.contains("Could not resolve type id")){
+                        throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " +
+                                "layer, vertex or preprocessor, in pre-1.0.0-beta JSON format.\nModels in legacy format with custom" +
+                                " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
+                    }
+                    throw new RuntimeException(e2);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (IOException e) {
             //Check if this exception came from legacy deserializer...
             String msg = e.getMessage();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
index d8e77a55a..0da5bea13 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
@@ -26,20 +26,14 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.distribution.Distribution;
 import org.deeplearning4j.nn.conf.dropout.Dropout;
 import org.deeplearning4j.nn.conf.dropout.IDropout;
-import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
 import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
-import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
 import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.conf.serde.JsonMappers;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializer;
 import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
 import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
 import org.deeplearning4j.nn.weights.IWeightInit;
@@ -59,10 +53,6 @@ import org.nd4j.linalg.learning.regularization.L1Regularization;
 import org.nd4j.linalg.learning.regularization.L2Regularization;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.learning.regularization.WeightDecay;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
-import org.nd4j.linalg.primitives.Pair;
-import org.nd4j.serde.json.LegacyIActivationDeserializer;
-import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 
 import java.io.IOException;
@@ -342,9 +332,7 @@ public class NeuralNetConfiguration implements Serializable, Cloneable {
         ObjectMapper mapper = mapper();
 
         try {
-            String ret = mapper.writeValueAsString(this);
-            return ret;
-
+            return mapper.writeValueAsString(this);
         } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
             throw new RuntimeException(e);
         }
@@ -384,86 +372,6 @@ public class NeuralNetConfiguration implements Serializable, Cloneable {
         return JsonMappers.getMapper();
     }
 
-    /**
-     * Set of classes that can be registered for legacy deserialization.
-     */
-    private static List<Class<?>> REGISTERABLE_CUSTOM_CLASSES = (List<Class<?>>) Arrays.<Class<?>>asList(
-            Layer.class,
-            GraphVertex.class,
-            InputPreProcessor.class,
-            IActivation.class,
-            ILossFunction.class,
-            ReconstructionDistribution.class
-    );
-
-    /**
-     * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
-     * ONLY) for JSON deserialization.<br>
-     * <br>
-     * This is required ONLY when BOTH of the following conditions are met:<br>
-     * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND<br>
-     * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J)<br>
-     * <br>
-     * By passing the classes of these layers here, DL4J should be able to deserialize them, in spite of the JSON
-     * format change between versions.
-     *
-     * @param classes Classes to register
-     */
-    public static void registerLegacyCustomClassesForJSON(Class<?>... classes) {
-        registerLegacyCustomClassesForJSONList(Arrays.<Class<?>>asList(classes));
-    }
-
-    /**
-     * @see #registerLegacyCustomClassesForJSON(Class[])
-     */
-    public static void registerLegacyCustomClassesForJSONList(List<Class<?>> classes){
-        //Default names (i.e., old format for custom JSON format)
-        List<Pair<String, Class>> list = new ArrayList<>();
-        for(Class<?> c : classes){
-            list.add(new Pair<String, Class>(c.getSimpleName(), c));
-        }
-        registerLegacyCustomClassesForJSON(list);
-    }
-
-    /**
-     * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
-     * ONLY) for JSON deserialization, with custom names.<br>
-     * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
-     * but is added in case it is required in non-standard circumstances.
-     */
-    public static void registerLegacyCustomClassesForJSON(List<Pair<String, Class>> classes){
-        for(Pair<String, Class> p : classes){
-            String s = p.getFirst();
-            Class c = p.getRight();
-            //Check if it's a valid class to register...
-            boolean found = false;
-            for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){
-                if(c2.isAssignableFrom(c)){
-                    if(c2 == Layer.class){
-                        LegacyLayerDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends Layer>)c);
-                    } else if(c2 == GraphVertex.class){
-                        LegacyGraphVertexDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends GraphVertex>)c);
-                    } else if(c2 == InputPreProcessor.class){
-                        LegacyPreprocessorDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends InputPreProcessor>)c);
-                    } else if(c2 == IActivation.class ){
-                        LegacyIActivationDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends IActivation>)c);
-                    } else if(c2 == ILossFunction.class ){
-                        LegacyILossFunctionDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends ILossFunction>)c);
-                    } else if(c2 == ReconstructionDistribution.class){
-                        LegacyReconstructionDistributionDeserializer.registerLegacyClassSpecifiedName(s, (Class<? extends ReconstructionDistribution>)c);
-                    }
-
-                    found = true;
-                }
-            }
-
-            if(!found){
-                throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
-                        c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
-            }
-        }
-    }
-
     /**
      * NeuralNetConfiguration builder, used as a starting point for creating a MultiLayerConfiguration or
      * ComputationGraphConfiguration.<br>
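With the registration API above removed, the one-off upgrade path for a pre-1.0.0-beta model containing custom layers is to round-trip it through an intermediate DL4J version, as the new error messages state. A sketch of that migration, to be run once on 1.0.0-beta through 1.0.0-beta4 (where the API still exists); file names are hypothetical and CustomLayer stands for the user's own layer class, as in the regression test above:

import java.io.File;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class LegacyModelMigration {
    public static void main(String[] args) throws Exception {
        //Run on DL4J 1.0.0-beta..1.0.0-beta4, not on 1.0.0-beta5+
        NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
        MultiLayerNetwork net = MultiLayerNetwork.load(new File("legacy-custom-layer-model.bin"), true);
        //Re-saving writes the current @class-based JSON format
        net.save(new File("migrated-model.bin"));
        //The migrated file then loads on 1.0.0-beta5+ with plain MultiLayerNetwork.load, no registration
    }
}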
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java index 2cd8d83f4..6ca3b35de 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java @@ -15,7 +15,7 @@ ******************************************************************************/ package org.deeplearning4j.nn.conf.graph; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.inputs.InputType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java index e4968a49e..497cc77f5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java @@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf.graph; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializerHelper; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,8 +32,7 @@ import java.io.Serializable; * * @author Alex Black */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyGraphVertexDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public abstract class GraphVertex implements Cloneable, Serializable { @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 805e1729c..85da86fa2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -16,12 +16,15 @@ package org.deeplearning4j.nn.conf.inputs; -import lombok.*; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonIgnore; import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; +import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.Serializable; @@ -36,12 +39,7 @@ import java.util.Arrays; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) -@JsonSubTypes(value = {@JsonSubTypes.Type(value = InputType.InputTypeFeedForward.class, name = "FeedForward"), - @JsonSubTypes.Type(value = 
InputType.InputTypeRecurrent.class, name = "Recurrent"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutional.class, name = "Convolutional"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutionalFlat.class, name = "ConvolutionalFlat"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutional3D.class, name = "Convolutional3D")}) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public abstract class InputType implements Serializable { /** @@ -174,13 +172,16 @@ public abstract class InputType implements Serializable { } - @AllArgsConstructor - @Getter @NoArgsConstructor + @Getter @EqualsAndHashCode(callSuper = false) public static class InputTypeFeedForward extends InputType { private long size; + public InputTypeFeedForward(@JsonProperty("size") long size) { + this.size = size; + } + @Override public Type getType() { return Type.FF; @@ -203,9 +204,8 @@ public abstract class InputType implements Serializable { } } - @Getter @NoArgsConstructor - @AllArgsConstructor + @Getter @EqualsAndHashCode(callSuper = false) public static class InputTypeRecurrent extends InputType { private long size; @@ -215,6 +215,11 @@ public abstract class InputType implements Serializable { this(size, -1); } + public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) { + this.size = size; + this.timeSeriesLength = timeSeriesLength; + } + @Override public Type getType() { return Type.RNN; @@ -245,15 +250,19 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutional extends InputType { private long height; private long width; private long channels; + public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + this.height = height; + this.width = width; + this.channels = channels; + } /** * Return the number of channels / depth for this 2D convolution. 
This method has been deprecated, @@ -298,10 +307,9 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutional3D extends InputType { private Convolution3D.DataFormat dataFormat; private long depth; @@ -309,6 +317,15 @@ public abstract class InputType implements Serializable { private long width; private long channels; + public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat, + @JsonProperty("depth") long depth, @JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + this.dataFormat = dataFormat; + this.depth = depth; + this.height = height; + this.width = width; + this.channels = channels; + } + @Override public Type getType() { return Type.CNN3D; @@ -336,15 +353,20 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutionalFlat extends InputType { private long height; private long width; private long depth; + public InputTypeConvolutionalFlat(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("depth") long depth) { + this.height = height; + this.width = width; + this.depth = depth; + } + @Override public Type getType() { return Type.CNNFlat; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index 5dfb3b671..25577bd1f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -29,7 +29,6 @@ import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializerHelper; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -45,8 +44,7 @@ import java.util.*; * A neural network layer. 
*/ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyLayerDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @Data @NoArgsConstructor public abstract class Layer implements TrainingConfig, Serializable, Cloneable { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 1e5eb11f2..0f1a770a8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -22,7 +22,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyIntArrayDeserializer; +import org.deeplearning4j.nn.conf.serde.legacy.LegacyIntArrayDeserializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.linalg.api.buffer.DataType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index 7b2317a8f..f72da09e5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.serde.FrozenLayerDeserializer; import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -48,7 +47,6 @@ import java.util.List; * @author Alex Black */ @EqualsAndHashCode(callSuper = false) -@JsonDeserialize(using = FrozenLayerDeserializer.class) public class FrozenLayer extends Layer { @Getter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java index 0de9143fe..f58e6b0e7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java @@ -16,7 +16,6 @@ package org.deeplearning4j.nn.conf.layers.variational; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializerHelper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -32,8 +31,7 @@ import java.io.Serializable; * * @author Alex Black */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = 
JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyReconstructionDistributionDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ReconstructionDistribution extends Serializable { /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index 5194da4a1..d32488363 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -21,12 +21,18 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.weights.*; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.WeightDecay; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.impl.*; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.databind.DeserializationContext; @@ -38,6 +44,8 @@ import org.nd4j.shade.jackson.databind.node.ObjectNode; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; /** * A custom (abstract) deserializer that handles backward compatibility (currently only for updater refactoring that @@ -103,6 +111,24 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im return false; } + protected boolean requiresActivationFromLegacy(Layer[] layers){ + for(Layer l : layers){ + if(l instanceof BaseLayer && ((BaseLayer)l).getActivationFn() == null){ + return true; + } + } + return false; + } + + protected boolean requiresLegacyLossHandling(Layer[] layers){ + for(Layer l : layers){ + if(l instanceof BaseOutputLayer && ((BaseOutputLayer)l).getLossFn() == null){ + return true; + } + } + return false; + } +
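To make concrete what these scans detect, here is a minimal standalone sketch (illustrative only, not part of the patch; the JSON literals are hand-written samples rather than real saved-network output) of the field-presence difference that triggers the legacy handlers further down:

import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.node.ObjectNode;

public class LegacyFormatProbe {
    public static void main(String[] args) throws Exception {
        ObjectMapper om = new ObjectMapper();
        //Pre-0.7.1 layers stored the activation as a bare string field
        ObjectNode legacy = (ObjectNode) om.readTree("{\"activationFunction\" : \"softmax\"}");
        //Current layers store a typed object; the subtype travels in an "@class" field
        ObjectNode current = (ObjectNode) om.readTree("{\"activationFn\" : {\"@class\" : \"org.nd4j.linalg.activations.impl.ActivationSoftmax\"}}");
        System.out.println(legacy.has("activationFunction"));  //true  -> legacy handling required
        System.out.println(current.has("activationFunction")); //false -> normal deserialization
    }
}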
protected void handleUpdaterBackwardCompatibility(BaseLayer layer, ObjectNode on){ if(on != null && on.has("updater")){ String updaterName = on.get("updater").asText(); @@ -220,7 +246,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im } protected void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ - if(on != null && (on.has("weightInit") )){ + if(on != null && on.has("weightInit") ){ //Legacy format JSON if(on.has("weightInit")){ String wi = on.get("weightInit").asText(); @@ -228,8 +254,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im WeightInit w = WeightInit.valueOf(wi); Distribution d = null; if(w == WeightInit.DISTRIBUTION && on.has("dist")){ - //TODO deserialize distribution - String dist = on.get("dist").asText(); + String dist = on.get("dist").toString(); d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class); } IWeightInit iwi = w.getWeightInitFunction(d); baseLayer.setWeightInitFn(iwi); @@ -241,6 +266,57 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im } } + //Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : <IActivation object> + protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ + if(baseLayer.getActivationFn() == null && on.has("activationFunction")){ + String afn = on.get("activationFunction").asText(); + IActivation a = null; + try { + a = getMap().get(afn.toLowerCase()).newInstance(); + } catch (InstantiationException | IllegalAccessException e){ + //Ignore + } + baseLayer.setActivationFn(a); + } + } + + //0.5.0 and earlier: loss function was an enum like "lossFunction" : "NEGATIVELOGLIKELIHOOD", + protected void handleLossBackwardCompatibility(BaseOutputLayer baseLayer, ObjectNode on){ + if(baseLayer.getLossFn() == null && on.has("lossFunction")) { + String lfn = on.get("lossFunction").asText(); + ILossFunction loss = null; + switch (lfn) { + case "MCXENT": + loss = new LossMCXENT(); + break; + case "MSE": + loss = new LossMSE(); + break; + case "NEGATIVELOGLIKELIHOOD": + loss = new LossNegativeLogLikelihood(); + break; + case "SQUARED_LOSS": + loss = new LossL2(); + break; + case "XENT": + loss = new LossBinaryXENT(); + break; + } + baseLayer.setLossFn(loss); + } + } + + private static Map<String, Class<? extends IActivation>> activationMap; + private static synchronized Map<String, Class<? extends IActivation>> getMap(){ + if(activationMap == null){ + activationMap = new HashMap<>(); + for(Activation a : Activation.values()){ + activationMap.put(a.toString().toLowerCase(), a.getActivationFunction().getClass()); + } + } + return activationMap; + } + @Override public void resolve(DeserializationContext ctxt) throws JsonMappingException { ((ResolvableDeserializer) defaultDeserializer).resolve(ctxt); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java index 1c30b053b..50384e518 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; @@ -74,6 +75,8 @@ public class ComputationGraphConfigurationDeserializer boolean attemptIUpdaterFromLegacy = requiresIUpdaterFromLegacy(layers); boolean requireLegacyRegularizationHandling = requiresRegularizationFromLegacy(layers); boolean requiresLegacyWeightInitHandling = requiresWeightInitFromLegacy(layers); + boolean requiresLegacyActivationHandling = requiresActivationFromLegacy(layers); + boolean requiresLegacyLossHandling = requiresLegacyLossHandling(layers); Long charOffsetEnd = null; JsonLocation endLocation = null;
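The activation lookup above relies on ND4J's Activation enum exposing its implementation class via getActivationFunction(). A minimal standalone sketch of the same mapping for one legacy value (this assumes the enum constant matches the upper-cased legacy string, which holds for simple names like "softmax"; the map built once in getMap() covers all values uniformly instead):

import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;

public class ActivationLookupSketch {
    public static void main(String[] args) {
        //As stored by pre-0.7.1 JSON: "activationFunction" : "softmax"
        String afn = "softmax";
        IActivation a = Activation.valueOf(afn.toUpperCase()).getActivationFunction();
        System.out.println(a.getClass().getSimpleName()); //ActivationSoftmax
    }
}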
@@ -123,6 +126,14 @@ public class ComputationGraphConfigurationDeserializer handleWeightInitBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); } + if(requiresLegacyActivationHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + } + + if(requiresLegacyLossHandling && layers[layerIdx] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[layerIdx]).getLossFn() == null){ + handleLossBackwardCompatibility((BaseOutputLayer) layers[layerIdx], (ObjectNode)next); + } + if(layers[layerIdx].getIDropout() == null){ //Check for legacy dropout if(next.has("dropOut")){ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java deleted file mode 100644 index 0eb618e20..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde; - -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; -import org.nd4j.shade.jackson.core.JsonFactory; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * A custom deserializer for handling Frozen layers - * This is used to handle the 2 different Layer JSON formats - old/legacy, and current - * - * @author Alex Black - */ -public class FrozenLayerDeserializer extends JsonDeserializer<Layer> { - @Override - public Layer deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode n = jp.getCodec().readTree(jp); - JsonNode layer = n.get("layer"); - boolean newFormat = layer.has("@class"); - - String internalText = layer.toString(); - Layer internal; - if(newFormat){ - //Standard/new format - internal = NeuralNetConfiguration.mapper().readValue(internalText, Layer.class); - } else { - //Legacy format - JsonFactory factory = new JsonFactory(); - JsonParser parser = factory.createParser(internalText); - parser.setCodec(jp.getCodec()); - internal = new LegacyLayerDeserializer().deserialize(parser, deserializationContext); - } - return new FrozenLayer(internal); - } -}
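The has("@class") test in the deleted deserializer is the same format probe the new mixin-based path relies on. For reference, a standalone sketch of the two FrozenLayer encodings it told apart (the JSON literals are indicative, hand-written samples, not actual DL4J output):

import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class FrozenLayerFormatSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper om = new ObjectMapper();
        //Current format: the wrapped layer carries its subtype in "@class"
        JsonNode current = om.readTree("{\"layer\" : {\"@class\" : \"org.deeplearning4j.nn.conf.layers.DenseLayer\"}}");
        //Pre-1.0.0-beta format: the subtype is a wrapper object keyed by a legacy name
        JsonNode legacy = om.readTree("{\"layer\" : {\"dense\" : {}}}");
        System.out.println(current.get("layer").has("@class")); //true  -> new format
        System.out.println(legacy.get("layer").has("@class"));  //false -> legacy path
    }
}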
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java index a4e29c86e..1cdd85d4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; +import org.deeplearning4j.nn.conf.serde.legacy.LegacyJsonFormat; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.serde.json.LegacyIActivationDeserializer; @@ -42,8 +43,10 @@ import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector; import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder; import org.nd4j.shade.jackson.databind.module.SimpleModule; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; +import org.nd4j.util.OneTimeLogger; import java.lang.annotation.Annotation; +import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -56,93 +59,14 @@ import java.util.List; @Slf4j public class JsonMappers { - /** - * @deprecated Use {@link DL4JSystemProperties#CUSTOM_REGISTRATION_PROPERTY} - */ - @Deprecated - public static String CUSTOM_REGISTRATION_PROPERTY = DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY; - - static { - String p = System.getProperty(DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY); - if(p != null && !p.isEmpty()){ - String[] split = p.split(","); - List<Class<?>> list = new ArrayList<>(); - for(String s : split){ - try{ - Class<?> c = Class.forName(s); - list.add(c); - } catch (Throwable t){ - log.warn("Error parsing {} system property: class \"{}\" could not be loaded",DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY, s, t); - } - } - - if(list.size() > 0){ - try { - NeuralNetConfiguration.registerLegacyCustomClassesForJSONList(list); - } catch (Throwable t){ - log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY, t); - } - } - } - } - private static ObjectMapper jsonMapper = new ObjectMapper(); private static ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); - /* - Note to developers: The following JSON mappers are for handling legacy format JSON. - Note that after 1.0.0-alpha, the JSON subtype format for layers, preprocessors, graph vertices, - etc were changed from a wrapper object, to an "@class" field. - However, in an attempt to not break saved networks, these mappers are part of the solution. - - How legacy loading works (same pattern for all types - Layer, GraphVertex, InputPreProcessor etc) - 1. Layers etc that have an "@class" field are deserialized as normal - 2. Layers that don't have such a field are mapped (via Layer @JsonTypeInfo) to the LegacyLayerDeserializerHelper class. - 3. LegacyLayerDeserializerHelper has a @JsonDeserialize annotation - we use LegacyLayerDeserializer to handle it - 4. LegacyLayerDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names - 5. 
BaseLegacyDeserializer (that LegacyLayerDeserializer extends) does a lookup and handles the deserialization - - Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format, - as it'll fail due to not having the expected "@class" annotation. - Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified - class anyway. The ignoring is done via an annotation introspector, defined below in this class. - However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of - all types - if we did, then any nested types would fail (i.e., an IActivation in a Layer - the IActivation couldn't - be deserialized correctly, as the annotation would be ignored). - - */ - @Getter - private static ObjectMapper jsonMapperLegacyFormatLayer = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatVertex = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatPreproc = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatIActivation = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatILoss = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatReconstruction = new ObjectMapper(); + private static ObjectMapper legacyMapper; static { configureMapper(jsonMapper); configureMapper(yamlMapper); - configureMapper(jsonMapperLegacyFormatLayer); - configureMapper(jsonMapperLegacyFormatVertex); - configureMapper(jsonMapperLegacyFormatPreproc); - configureMapper(jsonMapperLegacyFormatIActivation); - configureMapper(jsonMapperLegacyFormatILoss); - configureMapper(jsonMapperLegacyFormatReconstruction); - - jsonMapperLegacyFormatLayer.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(Layer.class))); - jsonMapperLegacyFormatVertex.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(GraphVertex.class))); - jsonMapperLegacyFormatPreproc.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(InputPreProcessor.class))); - jsonMapperLegacyFormatIActivation.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(IActivation.class))); - jsonMapperLegacyFormatILoss.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(ILossFunction.class))); - jsonMapperLegacyFormatReconstruction.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(ReconstructionDistribution.class))); - - LegacyIActivationDeserializer.setLegacyJsonMapper(jsonMapperLegacyFormatIActivation); - LegacyILossFunctionDeserializer.setLegacyJsonMapper(jsonMapperLegacyFormatILoss); } /** @@ -152,6 +76,14 @@ public class JsonMappers { return jsonMapper; } + public static synchronized ObjectMapper getLegacyMapper(){ + if(legacyMapper == null){ + legacyMapper = LegacyJsonFormat.getMapper100alpha(); + configureMapper(legacyMapper); + } + return legacyMapper; + } +
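With the per-type legacy mappers gone, old-format JSON goes through this single entry point. A rough usage sketch (the JSON fragment is a hypothetical 1.0.0-alpha style wrapper-object sample; it assumes the LayerMixin added later in this patch resolves the "dense" wrapper to DenseLayer):

import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class LegacyMapperSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper legacy = JsonMappers.getLegacyMapper();
        //Wrapper-object subtype naming, as emitted before 1.0.0-beta
        String oldJson = "{\"dense\" : {}}";
        Layer l = legacy.readValue(oldJson, Layer.class);
        System.out.println(l.getClass().getSimpleName()); //DenseLayer
    }
}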
/** * @return The default/primary ObjectMapper for deserializing network configurations in DL4J (YAML format) */ @@ -182,60 +114,4 @@ public class JsonMappers { ret.registerModule(customDeserializerModule); } - - - /** - * Custom Jackson Introspector to ignore the {@code @JsonTypeInfo} annotations on layers etc. - * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring - * a set of JsonTypeInfo annotations - */ - @AllArgsConstructor - private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector { - - private List<Class<?>> classList; - - @Override - protected TypeResolverBuilder<?> _findTypeResolver(MapperConfig<?> config, Annotated ann, JavaType baseType) { - if(ann instanceof AnnotatedClass){ - AnnotatedClass c = (AnnotatedClass)ann; - Class<?> annClass = c.getAnnotated(); - - boolean isAssignable = false; - for(Class<?> c2 : classList){ - if(c2.isAssignableFrom(annClass)){ - isAssignable = true; - break; - } - } - - if( isAssignable ){ - AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations(); - if(annotations == null || annotations.annotations() == null){ - //Probably not necessary - but here for safety - return super._findTypeResolver(config, ann, baseType); - } - - AnnotationMap newMap = null; - for(Annotation a : annotations.annotations()){ - Class<? extends Annotation> annType = a.annotationType(); - if(annType == JsonTypeInfo.class){ - //Ignore the JsonTypeInfo annotation on the Layer class - continue; - } - if(newMap == null){ - newMap = new AnnotationMap(); - } - newMap.add(a); - } - if(newMap == null) - return null; - - //Pass the remaining annotations (if any) to the original introspector - AnnotatedClass ann2 = c.withAnnotations(newMap); - return super._findTypeResolver(config, ann2, baseType); - } - } - return super._findTypeResolver(config, ann, baseType); - } - } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java index e7bcdc39a..028fef9d3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; @@ -59,6 +60,8 @@ public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializ boolean requiresLegacyRegularizationHandling = requiresRegularizationFromLegacy(layers); boolean requiresLegacyWeightInitHandling = requiresWeightInitFromLegacy(layers); + boolean requiresLegacyActivationHandling = requiresActivationFromLegacy(layers); + boolean requiresLegacyLossHandling = requiresLegacyLossHandling(layers); if(attemptIUpdaterFromLegacy || requiresLegacyRegularizationHandling || requiresLegacyWeightInitHandling) { JsonLocation endLocation = jp.getCurrentLocation(); @@ -115,38 +118,34 @@ public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializ } } - if(requiresLegacyRegularizationHandling) { - if (layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getRegularization() == null) { - if(on.has("layer")){ - //Legacy format - ObjectNode layerNode = (ObjectNode)on.get("layer"); - if(layerNode.has("@class")){ - //Later legacy 
format: class field for JSON subclass - on = layerNode; - } else { - //Early legacy format: wrapper object for JSON subclass - on = (ObjectNode) on.get("layer").elements().next(); - } + if(requiresLegacyRegularizationHandling || requiresLegacyWeightInitHandling || requiresLegacyActivationHandling || requiresLegacyLossHandling){ + if(on.has("layer")){ + //Legacy format + ObjectNode layerNode = (ObjectNode)on.get("layer"); + if(layerNode.has("@class")){ + //Later legacy format: class field for JSON subclass + on = layerNode; + } else { + //Early legacy format: wrapper object for JSON subclass + on = (ObjectNode) on.get("layer").elements().next(); } - handleL1L2BackwardCompatibility((BaseLayer) layers[i], on); } } - if(requiresLegacyWeightInitHandling){ - if (layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getWeightInitFn() == null) { - if(on.has("layer")){ - //Legacy format - ObjectNode layerNode = (ObjectNode)on.get("layer"); - if(layerNode.has("@class")){ - //Later legacy format: class field for JSON subclass - on = layerNode; - } else { - //Early legacy format: wrapper object for JSON subclass - on = (ObjectNode) on.get("layer").elements().next(); - } - } - handleWeightInitBackwardCompatibility((BaseLayer) layers[i], on); - } + if(requiresLegacyRegularizationHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getRegularization() == null) { + handleL1L2BackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getWeightInitFn() == null) { + handleWeightInitBackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyActivationHandling && layers[i] instanceof BaseLayer && ((BaseLayer)layers[i]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyLossHandling && layers[i] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[i]).getLossFn() == null){ + handleLossBackwardCompatibility((BaseOutputLayer) layers[i], on); } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java similarity index 97% rename from deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java rename to deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java index e080d6a78..064219fd1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.nn.conf.serde.legacyformat; +package org.deeplearning4j.nn.conf.serde.legacy; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.core.JsonProcessingException; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java new file mode 100644 index 000000000..e421c4b1f --- /dev/null +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java @@ -0,0 +1,175 @@ +package org.deeplearning4j.nn.conf.serde.legacy; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.graph.*; +import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; +import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; +import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.util.MaskLayer; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.variational.*; +import org.deeplearning4j.nn.conf.preprocessor.*; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.impl.*; +import org.nd4j.shade.jackson.annotation.JsonSubTypes; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +/** + * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override + * the existing annotations. + * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field". 
+ * We use these mixins to allow us to still load the old format + * + * @author Alex Black + */ +public class LegacyJsonFormat { + + private LegacyJsonFormat(){ } + + /** + * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before + * @return Object mapper + */ + public static ObjectMapper getMapper100alpha(){ + //After 1.0.0-alpha, we switched from wrapper object to @class for subtype information + ObjectMapper om = new ObjectMapper(); + + om.addMixIn(InputPreProcessor.class, InputPreProcessorMixin.class); + om.addMixIn(GraphVertex.class, GraphVertexMixin.class); + om.addMixIn(Layer.class, LayerMixin.class); + om.addMixIn(ReconstructionDistribution.class, ReconstructionDistributionMixin.class); + om.addMixIn(IActivation.class, IActivationMixin.class); + om.addMixIn(ILossFunction.class, ILossFunctionMixin.class); + + return om; + } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CnnToFeedForwardPreProcessor.class, name = "cnnToFeedForward"), + @JsonSubTypes.Type(value = CnnToRnnPreProcessor.class, name = "cnnToRnn"), + @JsonSubTypes.Type(value = ComposableInputPreProcessor.class, name = "composableInput"), + @JsonSubTypes.Type(value = FeedForwardToCnnPreProcessor.class, name = "feedForwardToCnn"), + @JsonSubTypes.Type(value = FeedForwardToRnnPreProcessor.class, name = "feedForwardToRnn"), + @JsonSubTypes.Type(value = RnnToFeedForwardPreProcessor.class, name = "rnnToFeedForward"), + @JsonSubTypes.Type(value = RnnToCnnPreProcessor.class, name = "rnnToCnn")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class InputPreProcessorMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ElementWiseVertex.class, name = "ElementWiseVertex"), + @JsonSubTypes.Type(value = MergeVertex.class, name = "MergeVertex"), + @JsonSubTypes.Type(value = SubsetVertex.class, name = "SubsetVertex"), + @JsonSubTypes.Type(value = LayerVertex.class, name = "LayerVertex"), + @JsonSubTypes.Type(value = LastTimeStepVertex.class, name = "LastTimeStepVertex"), + @JsonSubTypes.Type(value = ReverseTimeSeriesVertex.class, name = "ReverseTimeSeriesVertex"), + @JsonSubTypes.Type(value = DuplicateToTimeSeriesVertex.class, name = "DuplicateToTimeSeriesVertex"), + @JsonSubTypes.Type(value = PreprocessorVertex.class, name = "PreprocessorVertex"), + @JsonSubTypes.Type(value = StackVertex.class, name = "StackVertex"), + @JsonSubTypes.Type(value = UnstackVertex.class, name = "UnstackVertex"), + @JsonSubTypes.Type(value = L2Vertex.class, name = "L2Vertex"), + @JsonSubTypes.Type(value = ScaleVertex.class, name = "ScaleVertex"), + @JsonSubTypes.Type(value = L2NormalizeVertex.class, name = "L2NormalizeVertex")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class GraphVertexMixin{ } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = AutoEncoder.class, name = "autoEncoder"), + @JsonSubTypes.Type(value = ConvolutionLayer.class, name = "convolution"), + @JsonSubTypes.Type(value = Convolution1DLayer.class, name = "convolution1d"), + @JsonSubTypes.Type(value = GravesLSTM.class, name = "gravesLSTM"), + @JsonSubTypes.Type(value = LSTM.class, name = "LSTM"), + @JsonSubTypes.Type(value = GravesBidirectionalLSTM.class, name = "gravesBidirectionalLSTM"), + @JsonSubTypes.Type(value = 
OutputLayer.class, name = "output"), + @JsonSubTypes.Type(value = CenterLossOutputLayer.class, name = "CenterLossOutputLayer"), + @JsonSubTypes.Type(value = RnnOutputLayer.class, name = "rnnoutput"), + @JsonSubTypes.Type(value = LossLayer.class, name = "loss"), + @JsonSubTypes.Type(value = DenseLayer.class, name = "dense"), + @JsonSubTypes.Type(value = SubsamplingLayer.class, name = "subsampling"), + @JsonSubTypes.Type(value = Subsampling1DLayer.class, name = "subsampling1d"), + @JsonSubTypes.Type(value = BatchNormalization.class, name = "batchNormalization"), + @JsonSubTypes.Type(value = LocalResponseNormalization.class, name = "localResponseNormalization"), + @JsonSubTypes.Type(value = EmbeddingLayer.class, name = "embedding"), + @JsonSubTypes.Type(value = ActivationLayer.class, name = "activation"), + @JsonSubTypes.Type(value = VariationalAutoencoder.class, name = "VariationalAutoencoder"), + @JsonSubTypes.Type(value = DropoutLayer.class, name = "dropout"), + @JsonSubTypes.Type(value = GlobalPoolingLayer.class, name = "GlobalPooling"), + @JsonSubTypes.Type(value = ZeroPaddingLayer.class, name = "zeroPadding"), + @JsonSubTypes.Type(value = ZeroPadding1DLayer.class, name = "zeroPadding1d"), + @JsonSubTypes.Type(value = FrozenLayer.class, name = "FrozenLayer"), + @JsonSubTypes.Type(value = Upsampling2D.class, name = "Upsampling2D"), + @JsonSubTypes.Type(value = Yolo2OutputLayer.class, name = "Yolo2OutputLayer"), + @JsonSubTypes.Type(value = RnnLossLayer.class, name = "RnnLossLayer"), + @JsonSubTypes.Type(value = CnnLossLayer.class, name = "CnnLossLayer"), + @JsonSubTypes.Type(value = Bidirectional.class, name = "Bidirectional"), + @JsonSubTypes.Type(value = SimpleRnn.class, name = "SimpleRnn"), + @JsonSubTypes.Type(value = ElementWiseMultiplicationLayer.class, name = "ElementWiseMult"), + @JsonSubTypes.Type(value = MaskLayer.class, name = "MaskLayer"), + @JsonSubTypes.Type(value = MaskZeroLayer.class, name = "MaskZeroLayer"), + @JsonSubTypes.Type(value = Cropping1D.class, name = "Cropping1D"), + @JsonSubTypes.Type(value = Cropping2D.class, name = "Cropping2D")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class LayerMixin {} + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = GaussianReconstructionDistribution.class, name = "Gaussian"), + @JsonSubTypes.Type(value = BernoulliReconstructionDistribution.class, name = "Bernoulli"), + @JsonSubTypes.Type(value = ExponentialReconstructionDistribution.class, name = "Exponential"), + @JsonSubTypes.Type(value = CompositeReconstructionDistribution.class, name = "Composite"), + @JsonSubTypes.Type(value = LossFunctionWrapper.class, name = "LossWrapper")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ReconstructionDistributionMixin {} + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ActivationCube.class, name = "Cube"), + @JsonSubTypes.Type(value = ActivationELU.class, name = "ELU"), + @JsonSubTypes.Type(value = ActivationHardSigmoid.class, name = "HardSigmoid"), + @JsonSubTypes.Type(value = ActivationHardTanH.class, name = "HardTanh"), + @JsonSubTypes.Type(value = ActivationIdentity.class, name = "Identity"), + @JsonSubTypes.Type(value = ActivationLReLU.class, name = "LReLU"), + @JsonSubTypes.Type(value = ActivationRationalTanh.class, name = "RationalTanh"), + @JsonSubTypes.Type(value = ActivationRectifiedTanh.class, name = 
"RectifiedTanh"), + @JsonSubTypes.Type(value = ActivationSELU.class, name = "SELU"), + @JsonSubTypes.Type(value = ActivationSwish.class, name = "SWISH"), + @JsonSubTypes.Type(value = ActivationReLU.class, name = "ReLU"), + @JsonSubTypes.Type(value = ActivationRReLU.class, name = "RReLU"), + @JsonSubTypes.Type(value = ActivationSigmoid.class, name = "Sigmoid"), + @JsonSubTypes.Type(value = ActivationSoftmax.class, name = "Softmax"), + @JsonSubTypes.Type(value = ActivationSoftPlus.class, name = "SoftPlus"), + @JsonSubTypes.Type(value = ActivationSoftSign.class, name = "SoftSign"), + @JsonSubTypes.Type(value = ActivationTanH.class, name = "TanH")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class IActivationMixin {} + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = LossBinaryXENT.class, name = "BinaryXENT"), + @JsonSubTypes.Type(value = LossCosineProximity.class, name = "CosineProximity"), + @JsonSubTypes.Type(value = LossHinge.class, name = "Hinge"), + @JsonSubTypes.Type(value = LossKLD.class, name = "KLD"), + @JsonSubTypes.Type(value = LossMAE.class, name = "MAE"), + @JsonSubTypes.Type(value = LossL1.class, name = "L1"), + @JsonSubTypes.Type(value = LossMAPE.class, name = "MAPE"), + @JsonSubTypes.Type(value = LossMCXENT.class, name = "MCXENT"), + @JsonSubTypes.Type(value = LossMSE.class, name = "MSE"), + @JsonSubTypes.Type(value = LossL2.class, name = "L2"), + @JsonSubTypes.Type(value = LossMSLE.class, name = "MSLE"), + @JsonSubTypes.Type(value = LossNegativeLogLikelihood.class, name = "NegativeLogLikelihood"), + @JsonSubTypes.Type(value = LossPoisson.class, name = "Poisson"), + @JsonSubTypes.Type(value = LossSquaredHinge.class, name = "SquaredHinge"), + @JsonSubTypes.Type(value = LossFMeasure.class, name = "FMeasure")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ILossFunctionMixin {} +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java deleted file mode 100644 index 822d8fc3d..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java +++ /dev/null @@ -1,94 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.graph.*; -import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; -import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; -import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Deserializer for GraphVertex JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyGraphVertexDeserializer extends BaseLegacyDeserializer<GraphVertex> { - - private static final Map<String, String> LEGACY_NAMES = new HashMap<>(); - - static { - - - List<Class<?>> cList = Arrays.asList( - //All of these vertices had the legacy format name the same as the simple class name - MergeVertex.class, - SubsetVertex.class, - LayerVertex.class, - LastTimeStepVertex.class, - ReverseTimeSeriesVertex.class, - DuplicateToTimeSeriesVertex.class, - PreprocessorVertex.class, - StackVertex.class, - UnstackVertex.class, - L2Vertex.class, - ScaleVertex.class, - L2NormalizeVertex.class, - //These did not previously have a subtype annotation - they use default (which is simple class name) - ElementWiseVertex.class, - PoolHelperVertex.class, - ReshapeVertex.class, - ShiftVertex.class); - - for(Class<?> c : cList){ - LEGACY_NAMES.put(c.getSimpleName(), c.getName()); - } - } - - - @Override - public Map<String, String> getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { -// return JsonMappers.getMapperLegacyJson(); - return JsonMappers.getJsonMapperLegacyFormatVertex(); - } - - @Override - public Class<?> getDeserializedType() { - return GraphVertex.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class<?> clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class<?> clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java deleted file mode 100644 index bf9a2654a..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyGraphVertexDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.graph.GraphVertex} - */ -@JsonDeserialize(using = LegacyGraphVertexDeserializer.class) -public class LegacyGraphVertexDeserializerHelper { - private LegacyGraphVertexDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java deleted file mode 100644 index 9111dff87..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; -import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; -import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; -import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; -import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; -import org.deeplearning4j.nn.conf.layers.util.MaskLayer; -import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.HashMap; -import java.util.Map; - -/** - * Deserializer for Layer JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyLayerDeserializer extends BaseLegacyDeserializer<Layer> { - - private static final Map<String, String> LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("autoEncoder", AutoEncoder.class.getName()); - LEGACY_NAMES.put("convolution", ConvolutionLayer.class.getName()); - LEGACY_NAMES.put("convolution1d", Convolution1DLayer.class.getName()); - LEGACY_NAMES.put("gravesLSTM", GravesLSTM.class.getName()); - LEGACY_NAMES.put("LSTM", LSTM.class.getName()); - LEGACY_NAMES.put("gravesBidirectionalLSTM", GravesBidirectionalLSTM.class.getName()); - LEGACY_NAMES.put("output", OutputLayer.class.getName()); - LEGACY_NAMES.put("CenterLossOutputLayer", CenterLossOutputLayer.class.getName()); - LEGACY_NAMES.put("rnnoutput", RnnOutputLayer.class.getName()); - LEGACY_NAMES.put("loss", LossLayer.class.getName()); - LEGACY_NAMES.put("dense", DenseLayer.class.getName()); - LEGACY_NAMES.put("subsampling", SubsamplingLayer.class.getName()); - LEGACY_NAMES.put("subsampling1d", Subsampling1DLayer.class.getName()); - LEGACY_NAMES.put("batchNormalization", BatchNormalization.class.getName()); - LEGACY_NAMES.put("localResponseNormalization", LocalResponseNormalization.class.getName()); - LEGACY_NAMES.put("embedding", EmbeddingLayer.class.getName()); - LEGACY_NAMES.put("activation", ActivationLayer.class.getName()); - LEGACY_NAMES.put("VariationalAutoencoder", VariationalAutoencoder.class.getName()); - LEGACY_NAMES.put("dropout", DropoutLayer.class.getName()); - LEGACY_NAMES.put("GlobalPooling", GlobalPoolingLayer.class.getName()); - LEGACY_NAMES.put("zeroPadding", ZeroPaddingLayer.class.getName()); - LEGACY_NAMES.put("zeroPadding1d", ZeroPadding1DLayer.class.getName()); - LEGACY_NAMES.put("FrozenLayer", FrozenLayer.class.getName()); - LEGACY_NAMES.put("Upsampling2D", Upsampling2D.class.getName()); - LEGACY_NAMES.put("Yolo2OutputLayer", Yolo2OutputLayer.class.getName()); - LEGACY_NAMES.put("RnnLossLayer", RnnLossLayer.class.getName()); - LEGACY_NAMES.put("CnnLossLayer", CnnLossLayer.class.getName()); - LEGACY_NAMES.put("Bidirectional", Bidirectional.class.getName()); - LEGACY_NAMES.put("SimpleRnn", SimpleRnn.class.getName()); - LEGACY_NAMES.put("ElementWiseMult", ElementWiseMultiplicationLayer.class.getName()); - LEGACY_NAMES.put("MaskLayer", MaskLayer.class.getName()); - LEGACY_NAMES.put("MaskZeroLayer", MaskZeroLayer.class.getName()); - LEGACY_NAMES.put("Cropping1D", Cropping1D.class.getName()); - LEGACY_NAMES.put("Cropping2D", Cropping2D.class.getName()); - - //The following didn't previously have subtype annotations - hence will be using default name (class simple name) - LEGACY_NAMES.put("LastTimeStep", LastTimeStep.class.getName()); - LEGACY_NAMES.put("SpaceToDepthLayer", SpaceToDepthLayer.class.getName()); - LEGACY_NAMES.put("SpaceToBatchLayer", SpaceToBatchLayer.class.getName()); - } - - - @Override - public Map<String, String> getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getJsonMapperLegacyFormatLayer(); - } - - @Override - public Class<?> getDeserializedType() { - return Layer.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class<?> clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class<?> clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java deleted file mode 100644 index 76f7f7d7d..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyLayerDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.layers.Layer} - */ -@JsonDeserialize(using = LegacyLayerDeserializer.class) -public class LegacyLayerDeserializerHelper { - private LegacyLayerDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java deleted file mode 100644 index bed1c1c5c..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.graph.*; -import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; -import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; -import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; -import org.deeplearning4j.nn.conf.preprocessor.*; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Deserializer for InputPreProcessor JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyPreprocessorDeserializer extends BaseLegacyDeserializer<InputPreProcessor> { - - private static final Map<String, String> LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("cnnToFeedForward", CnnToFeedForwardPreProcessor.class.getName()); - LEGACY_NAMES.put("cnnToRnn", CnnToRnnPreProcessor.class.getName()); - LEGACY_NAMES.put("composableInput", ComposableInputPreProcessor.class.getName()); - LEGACY_NAMES.put("feedForwardToCnn", FeedForwardToCnnPreProcessor.class.getName()); - LEGACY_NAMES.put("feedForwardToRnn", FeedForwardToRnnPreProcessor.class.getName()); - LEGACY_NAMES.put("rnnToFeedForward", RnnToFeedForwardPreProcessor.class.getName()); - LEGACY_NAMES.put("rnnToCnn", RnnToCnnPreProcessor.class.getName()); - - //Keras preprocessors: they defaulted to class simple name - LEGACY_NAMES.put("KerasFlattenRnnPreprocessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor"); - LEGACY_NAMES.put("ReshapePreprocessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor"); - LEGACY_NAMES.put("TensorFlowCnnToFeedForwardPreProcessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor"); - } - - - @Override - public Map<String, String> getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { -// return JsonMappers.getMapperLegacyJson(); - return JsonMappers.getJsonMapperLegacyFormatPreproc(); - } - - @Override - public Class<?> getDeserializedType() { - return InputPreProcessor.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class<?> clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class<?> clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java deleted file mode 100644 index 19300ba5f..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyPreprocessorDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.InputPreProcessor} - */ -@JsonDeserialize(using = LegacyPreprocessorDeserializer.class) -public class LegacyPreprocessorDeserializerHelper { - private LegacyPreprocessorDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java deleted file mode 100644 index 06cf37f7e..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.layers.variational.*; -import org.deeplearning4j.nn.conf.preprocessor.*; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.HashMap; -import java.util.Map; - -/** - * Deserializer for ReconstructionDistribution JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyReconstructionDistributionDeserializer extends BaseLegacyDeserializer { - - private static final Map LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("Gaussian", GaussianReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Bernoulli", BernoulliReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Exponential", ExponentialReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Composite", CompositeReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("LossWrapper", LossFunctionWrapper.class.getName()); - } - - - @Override - public Map getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getJsonMapperLegacyFormatReconstruction(); - } - - @Override - public Class getDeserializedType() { - return ReconstructionDistribution.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java deleted file mode 100644 index 61952d62e..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
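Both deleted deserializers expose the same two registration hooks for custom subclasses. Against the pre-upgrade API they were called as below (MyCustomDistribution is a hypothetical user class; a real one would implement org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution):

    import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializer;

    // Hypothetical custom distribution used only for illustration
    class MyCustomDistribution { }

    class LegacyRegistrationSketch {
        public static void main(String[] args) {
            // Register under the simple class name ("MyCustomDistribution")...
            LegacyReconstructionDistributionDeserializer
                    .registerLegacyClassDefaultName(MyCustomDistribution.class);
            // ...or under the exact name that appears in the old JSON
            LegacyReconstructionDistributionDeserializer
                    .registerLegacyClassSpecifiedName("MyCustom", MyCustomDistribution.class);
        }
    }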
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyReconstructionDistributionDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} - */ -@JsonDeserialize(using = LegacyReconstructionDistributionDeserializer.class) -public class LegacyReconstructionDistributionDeserializerHelper { - private LegacyReconstructionDistributionDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java index e40fdcd32..5ce11cf5c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java @@ -16,7 +16,7 @@ package org.deeplearning4j.nn.workspace; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java index 8026d0764..b67961985 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java @@ -16,7 +16,7 @@ package org.deeplearning4j.optimize.listeners; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java index f474ee24b..9ff5993c3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java @@ -16,7 +16,7 @@ package org.deeplearning4j.optimize.listeners; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java index 42b63a849..dc632cecd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodingHandler.java @@ -16,7 +16,7 @@ package org.deeplearning4j.optimize.solvers.accumulation; -import 
com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index f37ac245c..70945b06c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -16,7 +16,7 @@ package org.deeplearning4j.util; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml index 4a8b6f230..0b6b05c26 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml @@ -38,13 +38,6 @@ - - 2.1.0 - 2 - 2.11.12 2.11 @@ -102,11 +95,6 @@ jsch ${jsch.version} - - com.google.guava - guava - ${guava.version} - com.google.inject guice @@ -152,21 +140,6 @@ jaxb-impl ${jaxb.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - io.netty netty diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index 8fd0ea38f..eaf0e6be8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -72,27 +72,11 @@ test - - - io.netty - netty - ${netty.version} - org.scala-lang scala-library ${scala.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - - - com.typesafe - config - ${typesafe.config.version} - ch.qos.logback diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index 1af2f4600..4ba5ba4ce 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -16,7 +16,7 @@ package org.deeplearning4j.parallelism; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingResult; diff --git 
a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java index 4c9520ecd..432e22695 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java @@ -16,7 +16,7 @@ package org.deeplearning4j.parallelism.inference.observers; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 16e0176a8..3fded3e4a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -32,42 +32,6 @@ 3.4.2 - - - - - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_${scala.binary.version} - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark2.jackson.version} - - - - org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java index 9164a41d0..d14668154 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.models.sequencevectors.functions; import lombok.NonNull; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; @@ -31,64 +30,56 @@ import org.nd4j.parameterserver.distributed.training.TrainingDriver; import org.nd4j.parameterserver.distributed.transport.RoutedTransport; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; /** * @author raver119@gmail.com */ -public class VocabRddFunctionFlat extends 
BaseFlatMapFunctionAdaptee, T> { +public class VocabRddFunctionFlat implements FlatMapFunction, T> { + protected Broadcast vectorsConfigurationBroadcast; + protected Broadcast paramServerConfigurationBroadcast; + + protected transient VectorsConfiguration configuration; + protected transient SparkElementsLearningAlgorithm ela; + protected transient TrainingDriver driver; + public VocabRddFunctionFlat(@NonNull Broadcast vectorsConfigurationBroadcast, - @NonNull Broadcast paramServerConfigurationBroadcast) { - super(new VocabRddFunctionAdapter(vectorsConfigurationBroadcast, paramServerConfigurationBroadcast)); + @NonNull Broadcast paramServerConfigurationBroadcast) { + this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast; + this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast; } + @Override + public Iterator call(Sequence sequence) throws Exception { + if (configuration == null) + configuration = vectorsConfigurationBroadcast.getValue(); - private static class VocabRddFunctionAdapter - implements FlatMapFunctionAdapter, T> { - protected Broadcast vectorsConfigurationBroadcast; - protected Broadcast paramServerConfigurationBroadcast; - - protected transient VectorsConfiguration configuration; - protected transient SparkElementsLearningAlgorithm ela; - protected transient TrainingDriver driver; - - public VocabRddFunctionAdapter(@NonNull Broadcast vectorsConfigurationBroadcast, - @NonNull Broadcast paramServerConfigurationBroadcast) { - this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast; - this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast; - } - - @Override - public Iterable call(Sequence sequence) throws Exception { - if (configuration == null) - configuration = vectorsConfigurationBroadcast.getValue(); - - if (ela == null) { - try { - ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()) - .newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } + if (ela == null) { + try { + ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()) + .newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); } - driver = ela.getTrainingDriver(); - - // we just silently initialize server - VoidParameterServer.getInstance().init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), - driver); - - // TODO: call for initializeSeqVec here - - List elements = new ArrayList<>(); - - elements.addAll(sequence.getElements()); - - // FIXME: this is PROBABLY bad, we might want to ensure, there's no duplicates. - if (configuration.isTrainSequenceVectors()) - if (!sequence.getSequenceLabels().isEmpty()) - elements.addAll(sequence.getSequenceLabels()); - - return elements; } + driver = ela.getTrainingDriver(); + + // we just silently initialize server + VoidParameterServer.getInstance().init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), + driver); + + // TODO: call for initializeSeqVec here + + List elements = new ArrayList<>(); + + elements.addAll(sequence.getElements()); + + // FIXME: this is PROBABLY bad, we might want to ensure, there's no duplicates. 
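The pattern in this and the following hunks is the core Spark migration of the PR: the FlatMapFunctionAdapter indirection existed because Spark 1.x's FlatMapFunction.call returned an Iterable, while Spark 2.x returns an Iterator. With Spark 1 support dropped, each class now implements org.apache.spark.api.java.function.FlatMapFunction directly and ends with elements.iterator(). A minimal self-contained sketch of the 2.x contract:

    import java.util.Arrays;
    import java.util.Iterator;
    import org.apache.spark.api.java.function.FlatMapFunction;

    // Spark 2.x contract: call() returns an Iterator (Spark 1.x returned an Iterable)
    class SplitWords implements FlatMapFunction<String, String> {
        @Override
        public Iterator<String> call(String line) throws Exception {
            return Arrays.asList(line.split("\\s+")).iterator();
        }
    }
    // Usage: javaRdd.flatMap(new SplitWords())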
+ if (configuration.isTrainSequenceVectors()) + if (!sequence.getSequenceLabels().isEmpty()) + elements.addAll(sequence.getSequenceLabels()); + + return elements.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java index 4718d0118..bbad9f2e3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java @@ -16,28 +16,252 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; import scala.Tuple2; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; +import java.util.concurrent.atomic.AtomicLong; /** * @author jeffreytang * @author raver119@gmail.com */ -public class FirstIterationFunction extends - BaseFlatMapFunctionAdaptee, Long>>, Entry> { +public class FirstIterationFunction implements + FlatMapFunction, Long>>, Entry> { + + private int ithIteration = 1; + private int vectorLength; + private boolean useAdaGrad; + private int batchSize = 0; + private double negative; + private int window; + private double alpha; + private double minAlpha; + private long totalWordCount; + private long seed; + private int maxExp; + private double[] expTable; + private int iterations; + private Map indexSyn0VecMap; + private Map pointSyn1VecMap; + private AtomicLong nextRandom = new AtomicLong(5); + + private volatile VocabCache vocab; + private volatile NegativeHolder negativeHolder; + private AtomicLong cid = new AtomicLong(0); + private AtomicLong aff = new AtomicLong(0); + + public FirstIterationFunction(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - super(new FirstIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); + Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { + + Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); + this.expTable = expTableBroadcast.getValue(); + this.vectorLength = (int) word2vecVarMap.get("vectorLength"); + this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); + this.negative = (double) word2vecVarMap.get("negative"); + this.window = (int) word2vecVarMap.get("window"); + this.alpha = (double) word2vecVarMap.get("alpha"); + this.minAlpha = (double) word2vecVarMap.get("minAlpha"); + this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); + this.seed = (long) word2vecVarMap.get("seed"); + this.maxExp = (int) word2vecVarMap.get("maxExp"); + this.iterations = (int) word2vecVarMap.get("iterations"); + this.batchSize = (int) word2vecVarMap.get("batchSize"); + this.indexSyn0VecMap = new HashMap<>(); + this.pointSyn1VecMap = new 
HashMap<>(); + this.vocab = vocabCacheBroadcast.getValue(); + + if (this.vocab == null) + throw new RuntimeException("VocabCache is null"); + + if (negative > 0) { + negativeHolder = NegativeHolder.getInstance(); + negativeHolder.initHolder(vocab, expTable, this.vectorLength); + } + } + + + + @Override + public Iterator> call(Iterator, Long>> pairIter) { + while (pairIter.hasNext()) { + List, Long>> batch = new ArrayList<>(); + while (pairIter.hasNext() && batch.size() < batchSize) { + Tuple2, Long> pair = pairIter.next(); + List vocabWordsList = pair._1(); + Long sentenceCumSumCount = pair._2(); + batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); + } + + for (int i = 0; i < iterations; i++) { + //System.out.println("Training sentence: " + vocabWordsList); + for (Pair, Long> pair : batch) { + List vocabWordsList = pair.getKey(); + Long sentenceCumSumCount = pair.getValue(); + double currentSentenceAlpha = Math.max(minAlpha, + alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); + trainSentence(vocabWordsList, currentSentenceAlpha); + } + } + } + return indexSyn0VecMap.entrySet().iterator(); + } + + + public void trainSentence(List vocabWordsList, double currentSentenceAlpha) { + + if (vocabWordsList != null && !vocabWordsList.isEmpty()) { + for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { + // Random value ranging from 0 to window size + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + int b = (int) (long) this.nextRandom.get() % window; + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null) { + skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); + } + } + } + } + + public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { + + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null && !vocabWordsList.isEmpty()) { + int end = window * 2 + 1 - b; + for (int a = b; a < end; a++) { + if (a != window) { + int c = ithWordInSentence - window + a; + if (c >= 0 && c < vocabWordsList.size()) { + VocabWord lastWord = vocabWordsList.get(c); + iterateSample(currentWord, lastWord, currentSentenceAlpha); + } + } + } + } + } + + public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { + + + if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) + return; + final int currentWordIndex = w2.getIndex(); + + // error for current word and context + INDArray neu1e = Nd4j.create(vectorLength); + + // First iteration Syn0 is random numbers + INDArray l1 = null; + if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { + l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); + } else { + l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); + } + + // + for (int i = 0; i < w1.getCodeLength(); i++) { + int code = w1.getCodes().get(i); + int point = w1.getPoints().get(i); + if (point < 0) + throw new IllegalStateException("Illegal point " + point); + // Point to + INDArray syn1; + if (pointSyn1VecMap.containsKey(point)) { + syn1 = pointSyn1VecMap.get(point); + } else { + syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros + pointSyn1VecMap.put(point, syn1); + } + + // Dot product of Syn0 and Syn1 vecs + double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); + + if (dot < -maxExp || dot >= maxExp) + continue; + + int idx = (int) ((dot + maxExp) * 
((double) expTable.length / maxExp / 2.0)); + + if (idx >= expTable.length) + continue; + + //score + double f = expTable[idx]; + //gradient + double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) + : currentSentenceAlpha); + + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); + } + + int target = w1.getIndex(); + int label; + //negative sampling + if (negative > 0) + for (int d = 0; d < negative + 1; d++) { + if (d == 0) + label = 1; + else { + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + + // FIXME: int cast + int idx = Math.abs((int) (nextRandom.get() >> 16) % (int) negativeHolder.getTable().length()); + + target = negativeHolder.getTable().getInt(idx); + if (target <= 0) + target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; + + if (target == w1.getIndex()) + continue; + label = 0; + } + + if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) + continue; + + double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target)); + double g; + if (f > maxExp) + g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; + else if (f < -maxExp) + g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); + else { + int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); + if (idx >= expTable.length) + continue; + + g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) + : (label - expTable[idx]) * alpha; + } + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target), neu1e); + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, negativeHolder.getSyn1Neg().slice(target)); + } + + + // Update the Syn0 vector based on gradient. Syn0 is not random anymore. + Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); + + VocabWord word = vocab.elementAtIndex(currentWordIndex); + indexSyn0VecMap.put(word, l1); + } + + private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { + /* + we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on two distinct nodes, initial weights will be the same for the same word + */ + return Nd4j.rand( new int[] {1, vectorLength}, lseed * seed).subi(0.5).divi(vectorLength); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java deleted file mode 100644 index 05ca83fc6..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
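Two numeric details in the function above are worth making concrete: the per-sentence learning rate decays linearly from alpha toward minAlpha with the fraction of the corpus consumed, and sigmoid values are read from a precomputed exp table using the index arithmetic shown in the training loop. A small runnable check (table size and constants are illustrative, not the values DL4J uses):

    public class Word2VecMathSketch {
        public static void main(String[] args) {
            // 1) Linear learning-rate decay, as in the call() loop above
            double alpha = 0.025, minAlpha = 1e-4;
            long totalWordCount = 1_000_000L;
            long sentenceCumSumCount = 250_000L; // 25% of the corpus seen so far
            double currentSentenceAlpha = Math.max(minAlpha,
                    alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount));
            System.out.println(currentSentenceAlpha); // 0.018775

            // 2) Sigmoid lookup table, mirroring the idx arithmetic above
            int maxExp = 6, tableSize = 1000;
            double[] expTable = new double[tableSize];
            for (int i = 0; i < tableSize; i++) {
                double x = (i / (double) tableSize * 2 - 1) * maxExp; // index -> [-maxExp, maxExp)
                expTable[i] = 1.0 / (1.0 + Math.exp(-x));             // sigmoid(x)
            }
            double dot = 1.5; // an in-range syn0/syn1 dot product
            int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0));
            System.out.println(expTable[idx]); // ~0.817, i.e. sigmoid(1.5)
        }
    }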
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.word2vec; - -import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; -import scala.Tuple2; - -import java.util.*; -import java.util.concurrent.atomic.AtomicLong; - -/** - * @author jeffreytang - * @author raver119@gmail.com - */ -public class FirstIterationFunctionAdapter implements - FlatMapFunctionAdapter, Long>>, Map.Entry> { - - private int ithIteration = 1; - private int vectorLength; - private boolean useAdaGrad; - private int batchSize = 0; - private double negative; - private int window; - private double alpha; - private double minAlpha; - private long totalWordCount; - private long seed; - private int maxExp; - private double[] expTable; - private int iterations; - private Map indexSyn0VecMap; - private Map pointSyn1VecMap; - private AtomicLong nextRandom = new AtomicLong(5); - - private volatile VocabCache vocab; - private volatile NegativeHolder negativeHolder; - private AtomicLong cid = new AtomicLong(0); - private AtomicLong aff = new AtomicLong(0); - - - - public FirstIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - - Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); - this.expTable = expTableBroadcast.getValue(); - this.vectorLength = (int) word2vecVarMap.get("vectorLength"); - this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); - this.negative = (double) word2vecVarMap.get("negative"); - this.window = (int) word2vecVarMap.get("window"); - this.alpha = (double) word2vecVarMap.get("alpha"); - this.minAlpha = (double) word2vecVarMap.get("minAlpha"); - this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); - this.seed = (long) word2vecVarMap.get("seed"); - this.maxExp = (int) word2vecVarMap.get("maxExp"); - this.iterations = (int) word2vecVarMap.get("iterations"); - this.batchSize = (int) word2vecVarMap.get("batchSize"); - this.indexSyn0VecMap = new HashMap<>(); - this.pointSyn1VecMap = new HashMap<>(); - this.vocab = vocabCacheBroadcast.getValue(); - - if (this.vocab == null) - throw new RuntimeException("VocabCache is null"); - - if (negative > 0) { - negativeHolder = NegativeHolder.getInstance(); - negativeHolder.initHolder(vocab, expTable, this.vectorLength); - } - } - - - - @Override - public Iterable> call(Iterator, Long>> pairIter) { - while (pairIter.hasNext()) { - List, Long>> batch = new ArrayList<>(); - while (pairIter.hasNext() && batch.size() < batchSize) { - Tuple2, Long> pair = pairIter.next(); - List vocabWordsList = pair._1(); - Long sentenceCumSumCount = pair._2(); - batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); - } - - for (int i = 0; i < iterations; i++) { - //System.out.println("Training sentence: " + vocabWordsList); - for (Pair, Long> pair : batch) { - List vocabWordsList = pair.getKey(); - Long sentenceCumSumCount = pair.getValue(); - double currentSentenceAlpha = Math.max(minAlpha, - alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); - 
trainSentence(vocabWordsList, currentSentenceAlpha); - } - } - } - return indexSyn0VecMap.entrySet(); - } - - - public void trainSentence(List vocabWordsList, double currentSentenceAlpha) { - - if (vocabWordsList != null && !vocabWordsList.isEmpty()) { - for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { - // Random value ranging from 0 to window size - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - int b = (int) (long) this.nextRandom.get() % window; - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null) { - skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); - } - } - } - } - - public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { - - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null && !vocabWordsList.isEmpty()) { - int end = window * 2 + 1 - b; - for (int a = b; a < end; a++) { - if (a != window) { - int c = ithWordInSentence - window + a; - if (c >= 0 && c < vocabWordsList.size()) { - VocabWord lastWord = vocabWordsList.get(c); - iterateSample(currentWord, lastWord, currentSentenceAlpha); - } - } - } - } - } - - public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { - - - if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) - return; - final int currentWordIndex = w2.getIndex(); - - // error for current word and context - INDArray neu1e = Nd4j.create(vectorLength); - - // First iteration Syn0 is random numbers - INDArray l1 = null; - if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { - l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); - } else { - l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); - } - - // - for (int i = 0; i < w1.getCodeLength(); i++) { - int code = w1.getCodes().get(i); - int point = w1.getPoints().get(i); - if (point < 0) - throw new IllegalStateException("Illegal point " + point); - // Point to - INDArray syn1; - if (pointSyn1VecMap.containsKey(point)) { - syn1 = pointSyn1VecMap.get(point); - } else { - syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros - pointSyn1VecMap.put(point, syn1); - } - - // Dot product of Syn0 and Syn1 vecs - double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); - - if (dot < -maxExp || dot >= maxExp) - continue; - - int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0)); - - if (idx >= expTable.length) - continue; - - //score - double f = expTable[idx]; - //gradient - double g = (1 - code - f) * (useAdaGrad ? 
w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) - : currentSentenceAlpha); - - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); - } - - int target = w1.getIndex(); - int label; - //negative sampling - if (negative > 0) - for (int d = 0; d < negative + 1; d++) { - if (d == 0) - label = 1; - else { - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - - // FIXME: int cast - int idx = Math.abs((int) (nextRandom.get() >> 16) % (int) negativeHolder.getTable().length()); - - target = negativeHolder.getTable().getInt(idx); - if (target <= 0) - target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; - - if (target == w1.getIndex()) - continue; - label = 0; - } - - if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) - continue; - - double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target)); - double g; - if (f > maxExp) - g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; - else if (f < -maxExp) - g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); - else { - int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); - if (idx >= expTable.length) - continue; - - g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) - : (label - expTable[idx]) * alpha; - } - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target), neu1e); - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, negativeHolder.getSyn1Neg().slice(target)); - } - - - // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. - Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); - - VocabWord word = vocab.elementAtIndex(currentWordIndex); - indexSyn0VecMap.put(word, l1); - } - - private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { - /* - we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on dwo distinct nodes, initial weights will be the same for the same word - */ - return Nd4j.rand( new int[] {1, vectorLength}, lseed * seed).subi(0.5).divi(vectorLength); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java index 50ec16ff9..c34156484 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; @@ -37,22 +36,7 @@ import java.util.concurrent.atomic.AtomicLong; * @author jeffreytang * @author raver119@gmail.com */ -public class SecondIterationFunction extends - BaseFlatMapFunctionAdaptee, 
Long>>, Entry> { - - public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - super(new SecondIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); - } -} - - -/** - * @author jeffreytang - * @author raver119@gmail.com - */ -class SecondIterationFunctionAdapter - implements FlatMapFunctionAdapter, Long>>, Entry> { +public class SecondIterationFunction implements FlatMapFunction, Long>>, Entry> { private int ithIteration = 1; private int vectorLength; @@ -78,7 +62,7 @@ class SecondIterationFunctionAdapter - public SecondIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, + public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); @@ -110,7 +94,7 @@ class SecondIterationFunctionAdapter @Override - public Iterable> call(Iterator, Long>> pairIter) { + public Iterator> call(Iterator, Long>> pairIter) { this.vocabHolder = VocabHolder.getInstance(); this.vocabHolder.setSeed(seed, vectorLength); @@ -139,7 +123,7 @@ class SecondIterationFunctionAdapter } } } - return vocabHolder.getSplit(vocab); + return vocabHolder.getSplit(vocab).iterator(); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index 0d82bf48c..d8f425286 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -32,42 +32,6 @@ UTF-8 - - - - - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_${scala.binary.version} - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark2.jackson.version} - - - - org.nd4j diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java index 135cfa1c4..900f0e63b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.parameterserver.functions; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper; @@ -32,28 +31,20 @@ import java.util.Iterator; * @author raver119@gmail.com */ -public class SharedFlatMapDataSet extends 
BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapDataSet(TrainingWorker worker) { - super(new SharedFlatMapDataSetAdapter(worker)); - } -} - - -class SharedFlatMapDataSetAdapter implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapDataSet implements FlatMapFunction, R> { private final SharedTrainingWorker worker; - public SharedFlatMapDataSetAdapter(TrainingWorker worker) { + public SharedFlatMapDataSet(TrainingWorker worker) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } /* @@ -70,6 +61,6 @@ class SharedFlatMapDataSetAdapter implements FlatMapFu // all threads in this executor will be blocked here until training finished SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java index d7de24f15..5ce338b0f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.parameterserver.functions; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper; @@ -31,30 +30,20 @@ import java.util.Iterator; /** * Created by raver119 on 13.06.17. */ -public class SharedFlatMapMultiDataSet - extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapMultiDataSet(TrainingWorker worker) { - super(new SharedFlatMapMultiDataSetAdapter(worker)); - } -} - - -class SharedFlatMapMultiDataSetAdapter - implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapMultiDataSet implements FlatMapFunction, R> { private final SharedTrainingWorker worker; - public SharedFlatMapMultiDataSetAdapter(TrainingWorker worker) { + public SharedFlatMapMultiDataSet(TrainingWorker worker) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } /* That's the place where we do our stuff. Here's the plan: @@ -70,6 +59,6 @@ class SharedFlatMapMultiDataSetAdapter // all threads in this executor will be blocked here until training finished SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java index 5028e077f..4c8192ae7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java @@ -18,9 +18,8 @@ package org.deeplearning4j.spark.parameterserver.functions; import org.apache.commons.io.LineIterator; import org.apache.hadoop.conf.Configuration; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -39,11 +38,7 @@ import java.util.Iterator; * * @author raver119@gmail.com */ -public class SharedFlatMapPaths extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapPaths(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { - super(new SharedFlatMapPathsAdapter(worker, loader, hadoopConfig)); - } +public class SharedFlatMapPaths implements FlatMapFunction, R> { public static File toTempFile(Iterator dataSetIterator) throws IOException { File f = Files.createTempFile("SharedFlatMapPaths",".txt").toFile(); @@ -56,17 +51,14 @@ public class SharedFlatMapPaths extends BaseFlatMapFun } return f; } -} - -class SharedFlatMapPathsAdapter implements FlatMapFunctionAdapter, R> { public static Configuration defaultConfig; protected final SharedTrainingWorker worker; protected final DataSetLoader loader; protected final Broadcast hadoopConfig; - public SharedFlatMapPathsAdapter(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { + public SharedFlatMapPaths(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; this.loader = loader; @@ -74,10 +66,10 @@ class SharedFlatMapPathsAdapter implements FlatMapFunc } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } // here we'll be converting out Strings coming out of iterator to DataSets // PathSparkDataSetIterator does that for us @@ -93,7 +85,7 @@ class SharedFlatMapPathsAdapter implements FlatMapFunc // first callee will become master, others will obey and die SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } finally { lineIter.close(); f.delete(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java index a8fbadb2b..3a8b1c213 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.parameterserver.functions; import org.apache.commons.io.LineIterator; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.MultiDataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -37,21 +36,13 @@ import java.util.Iterator; /** * @author raver119@gmail.com */ -public class SharedFlatMapPathsMDS extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapPathsMDS(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - super(new SharedFlatMapPathsMDSAdapter(worker, loader, hadoopConfig)); - } -} - - -class SharedFlatMapPathsMDSAdapter implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapPathsMDS implements FlatMapFunction, R> { protected final SharedTrainingWorker worker; protected final MultiDataSetLoader loader; protected final Broadcast hadoopConfig; - public SharedFlatMapPathsMDSAdapter(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { + public SharedFlatMapPathsMDS(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; this.loader = loader; @@ -59,10 +50,10 @@ class SharedFlatMapPathsMDSAdapter implements FlatMapF } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } // here we'll be converting out Strings coming out of iterator to DataSets // PathSparkDataSetIterator does that for us @@ -78,7 +69,7 @@ class SharedFlatMapPathsMDSAdapter implements FlatMapF // first callee will become master, others will obey and die SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } finally { lineIter.close(); f.delete(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java index e2c371c06..c3910968e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.api.worker; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -41,30 +40,16 @@ import java.util.Iterator; * * @author Alex Black */ -public class ExecuteWorkerFlatMap extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerFlatMapAdapter(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on DataSets. 
- * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerFlatMapAdapter implements FlatMapFunctionAdapter, R> { +public class ExecuteWorkerFlatMap implements FlatMapFunction, R> { private final TrainingWorker worker; - public ExecuteWorkerFlatMapAdapter(TrainingWorker worker) { + public ExecuteWorkerFlatMap(TrainingWorker worker) { this.worker = worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { WorkerConfiguration dataConfig = worker.getDataConfiguration(); final boolean isGraph = dataConfig.isGraphNetwork(); @@ -79,9 +64,9 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu Pair pair = worker.getFinalResultNoDataWithStats(); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { - return Collections.singletonList(worker.getFinalResultNoData()); + return Collections.singletonList(worker.getFinalResultNoData()).iterator(); } } @@ -131,7 +116,7 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu SparkTrainingStats returnStats = s.build(workerStats); result.getFirst().setStats(returnStats); - return Collections.singletonList(result.getFirst()); + return Collections.singletonList(result.getFirst()).iterator(); } } else { R result; @@ -141,7 +126,7 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu result = worker.processMinibatch(next, net, !batchedIterator.hasNext()); if (result != null) { //Terminate training immediately - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } @@ -155,12 +140,12 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu else pair = worker.getFinalResultWithStats(net); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { if (isGraph) - return Collections.singletonList(worker.getFinalResult(graph)); + return Collections.singletonList(worker.getFinalResult(graph)).iterator(); else - return Collections.singletonList(worker.getFinalResult(net)); + return Collections.singletonList(worker.getFinalResult(net)).iterator(); } } finally { //Make sure we shut down the async thread properly... 
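One behavioural detail these migrated workers share: an empty partition must short-circuit, and under the new signature that means returning Collections.emptyIterator() rather than an empty list, with single results wrapped via Collections.singletonList(x).iterator(). A generic sketch of the guard (the type parameter is a placeholder, the counting loop stands in for real training, and any state captured by a real Spark function must also be Serializable):

    import java.util.Collections;
    import java.util.Iterator;
    import org.apache.spark.api.java.function.FlatMapFunction;

    // Sketch of the guard pattern shared by the migrated worker functions:
    // short-circuit on empty partitions, otherwise emit a single result.
    class PartitionGuardSketch<T> implements FlatMapFunction<Iterator<T>, Long> {
        @Override
        public Iterator<Long> call(Iterator<T> partition) throws Exception {
            if (!partition.hasNext())
                return Collections.emptyIterator(); // empty partition: nothing to train on
            long count = 0;
            while (partition.hasNext()) { // stand-in for the actual training loop
                partition.next();
                count++;
            }
            return Collections.singletonList(count).iterator(); // one result per partition
        }
    }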
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java index 15fcd6b89..6fa148394 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.api.worker; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import lombok.AllArgsConstructor; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -39,31 +39,13 @@ import java.util.Iterator; * * @author Alex Black */ -public class ExecuteWorkerMultiDataSetFlatMap - extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on MultiDataSets. Used only in SparkComputationGraph implementation. - * - * @author Alex Black - */ -class ExecuteWorkerMultiDataSetFlatMapAdapter - implements FlatMapFunctionAdapter, R> { +@AllArgsConstructor +public class ExecuteWorkerMultiDataSetFlatMap implements FlatMapFunction, R> { private final TrainingWorker worker; - public ExecuteWorkerMultiDataSetFlatMapAdapter(TrainingWorker worker) { - this.worker = worker; - } - @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { WorkerConfiguration dataConfig = worker.getDataConfiguration(); boolean stats = dataConfig.isCollectTrainingStats(); @@ -75,7 +57,7 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter if (stats) s.logReturnTime(); //TODO return the results... 
- return Collections.emptyList(); //Sometimes: no data + return Collections.emptyIterator(); //Sometimes: no data } int batchSize = dataConfig.getBatchSizePerWorker(); @@ -118,13 +100,13 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter SparkTrainingStats returnStats = s.build(workerStats); result.getFirst().setStats(returnStats); - return Collections.singletonList(result.getFirst()); + return Collections.singletonList(result.getFirst()).iterator(); } } else { R result = worker.processMinibatch(next, net, !batchedIterator.hasNext()); if (result != null) { //Terminate training immediately - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } @@ -134,9 +116,9 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter s.logReturnTime(); Pair pair = worker.getFinalResultWithStats(net); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { - return Collections.singletonList(worker.getFinalResult(net)); + return Collections.singletonList(worker.getFinalResult(net)).iterator(); } } finally { Nd4j.getExecutioner().commit(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java index 2e1bdb646..4969a055b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator; @@ -33,32 +32,15 @@ import java.util.Iterator; * @author Alex Black */ @Deprecated -public class ExecuteWorkerPDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; public ExecuteWorkerPDSFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerPDSFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded using a PortableDataStream - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -@Deprecated -class ExecuteWorkerPDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; - - public ExecuteWorkerPDSFlatMapAdapter(TrainingWorker worker) { - this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); + this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { return workerFlatMap.call(new PortableDataStreamDataSetIterator(iter)); } } diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java index ebc4d4691..63b82bdaa 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamMultiDataSetIterator; @@ -33,32 +32,15 @@ import java.util.Iterator; * @author Alex Black */ @Deprecated -public class ExecuteWorkerPDSMDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPDSMDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; public ExecuteWorkerPDSMDSFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerPDSMDSFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized MultiDataSet objects, that can be loaded using a PortableDataStream - * Used for SparkComputationGraph implementations only - * - * @author Alex Black - */ -@Deprecated -class ExecuteWorkerPDSMDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; - - public ExecuteWorkerPDSMDSFlatMapAdapter(TrainingWorker worker) { - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { return workerFlatMap.call(new PortableDataStreamMultiDataSetIterator(iter)); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java index 5e3889a99..e26615ab8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -38,30 +37,15 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteWorkerPathFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPathFlatMap 
implements FlatMapFunction, R> { - public ExecuteWorkerPathFlatMap(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { - super(new ExecuteWorkerPathFlatMapAdapter<>(worker, loader, hadoopConfig)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) - * that is specified as a String - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerPathFlatMapAdapter implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; + private final FlatMapFunction, R> workerFlatMap; private final DataSetLoader dataSetLoader; private final int maxDataSetObjects; private final Broadcast hadoopConfig; - public ExecuteWorkerPathFlatMapAdapter(TrainingWorker worker, DataSetLoader dataSetLoader, Broadcast hadoopConfig) { - this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); + public ExecuteWorkerPathFlatMap(TrainingWorker worker, DataSetLoader dataSetLoader, Broadcast hadoopConfig) { + this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); this.dataSetLoader = dataSetLoader; this.hadoopConfig = hadoopConfig; @@ -84,7 +68,7 @@ class ExecuteWorkerPathFlatMapAdapter implements FlatM } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List list = new ArrayList<>(); int count = 0; while (iter.hasNext() && count++ < maxDataSetObjects) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java index 072425f18..47e53bd6c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.MultiDataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -38,31 +37,14 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteWorkerPathMDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - super(new ExecuteWorkerPathMDSFlatMapAdapter<>(worker, loader, hadoopConfig)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) - * that is specified as a String - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerPathMDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; +public class ExecuteWorkerPathMDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; private MultiDataSetLoader loader; private 
final int maxDataSetObjects; private final Broadcast hadoopConfig; - public ExecuteWorkerPathMDSFlatMapAdapter(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); + public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); this.loader = loader; this.hadoopConfig = hadoopConfig; @@ -85,7 +67,7 @@ class ExecuteWorkerPathMDSFlatMapAdapter } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List list = new ArrayList<>(); int count = 0; while (iter.hasNext() && count++ < maxDataSetObjects) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java index fd3a5a2fc..9513193e4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.data; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import lombok.AllArgsConstructor; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import java.util.ArrayList; @@ -38,37 +38,12 @@ import java.util.List; * * @author Alex Black */ -public class BatchDataSetsFunction extends BaseFlatMapFunctionAdaptee, DataSet> { - - public BatchDataSetsFunction(int minibatchSize) { - super(new BatchDataSetsFunctionAdapter(minibatchSize)); - } -} - - -/** - * Function used to batch DataSet objects together. Typically used to combine singe-example DataSet objects out of - * something like {@link org.deeplearning4j.spark.datavec.DataVecDataSetFunction} together into minibatches.
- *
- * Usage:
- * <pre>
- * {@code
- *      RDD<DataSet> mySingleExampleDataSets = ...;
- *      RDD<DataSet> batchData = mySingleExampleDataSets.mapPartitions(new BatchDataSetsFunction(batchSize));
- * }
- * </pre>
- * - * @author Alex Black - */ -class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { +@AllArgsConstructor +public class BatchDataSetsFunction implements FlatMapFunction, DataSet> { private final int minibatchSize; - public BatchDataSetsFunctionAdapter(int minibatchSize) { - this.minibatchSize = minibatchSize; - } - @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List out = new ArrayList<>(); while (iter.hasNext()) { List list = new ArrayList<>(); @@ -88,6 +63,6 @@ class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { - - public SplitDataSetsFunction() { - super(new SplitDataSetsFunctionAdapter()); - } -} - - -/** - * Take an existing DataSet object, and split it into multiple DataSet objects with one example in each - * - * Usage: - *
- * {@code
- *      RDD<DataSet> myBatchedExampleDataSets = ...;
- *      RDD<DataSet> singleExampleDataSets = myBatchedExampleDataSets.mapPartitions(new SplitDataSets(batchSize));
- * }
- * </pre>
- * - * @author Alex Black - */ -class SplitDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { +public class SplitDataSetsFunction implements FlatMapFunction, DataSet> { @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { List out = new ArrayList<>(); while (dataSetIterator.hasNext()) { out.addAll(dataSetIterator.next().asList()); } - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java index e0557d10c..1ccf54b91 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java @@ -17,12 +17,13 @@ package org.deeplearning4j.spark.data.shuffle; import org.apache.spark.api.java.JavaRDD; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import scala.Tuple2; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Random; @@ -34,34 +35,17 @@ import java.util.Random; * * @author Alex Black */ -public class SplitDataSetExamplesPairFlatMapFunction extends BasePairFlatMapFunctionAdaptee { - - public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { - super(new SplitDataSetExamplesPairFlatMapFunctionAdapter(maxKeyIndex)); - } -} - - -/** - * A PairFlatMapFunction that splits each example in a {@link DataSet} object into its own {@link DataSet}. - * Also adds a random key (integer value) in the range 0 to maxKeyIndex-1.
- * - * Used in {@link org.deeplearning4j.spark.util.SparkUtils#shuffleExamples(JavaRDD, int, int)} - * - * @author Alex Black - */ -class SplitDataSetExamplesPairFlatMapFunctionAdapter - implements FlatMapFunctionAdapter> { +public class SplitDataSetExamplesPairFlatMapFunction implements PairFlatMapFunction { private transient Random r; private int maxKeyIndex; - public SplitDataSetExamplesPairFlatMapFunctionAdapter(int maxKeyIndex) { + public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { this.maxKeyIndex = maxKeyIndex; } @Override - public Iterable> call(DataSet dataSet) throws Exception { + public Iterator> call(DataSet dataSet) throws Exception { if (r == null) { r = new Random(); } @@ -72,6 +56,6 @@ class SplitDataSetExamplesPairFlatMapFunctionAdapter out.add(new Tuple2<>(r.nextInt(maxKeyIndex), ds)); } - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java index 37ef1e7ab..653bbc75d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java @@ -16,9 +16,9 @@ package org.deeplearning4j.spark.datavec; +import lombok.AllArgsConstructor; import org.apache.spark.api.java.JavaRDD; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import java.io.Serializable; @@ -44,22 +44,12 @@ public class RDDMiniBatches implements Serializable { return toSplitJava.mapPartitions(new MiniBatchFunction(miniBatches)); } - public static class MiniBatchFunction extends BaseFlatMapFunctionAdaptee, DataSet> { - - public MiniBatchFunction(int batchSize) { - super(new MiniBatchFunctionAdapter(batchSize)); - } - } - - static class MiniBatchFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { - private int batchSize = 10; - - public MiniBatchFunctionAdapter(int batchSize) { - this.batchSize = batchSize; - } + @AllArgsConstructor + public static class MiniBatchFunction implements FlatMapFunction, DataSet> { + private int batchSize; @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { List ret = new ArrayList<>(); List temp = new ArrayList<>(); while (dataSetIterator.hasNext()) { @@ -74,10 +64,7 @@ public class RDDMiniBatches implements Serializable { if (temp.size() > 0) ret.add(DataSet.merge(temp)); - return ret; + return ret.iterator(); } - } - - } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java index 6d717371a..9412a3cbb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java +++ 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.impl.common.repartition; -import com.google.common.base.Predicate; -import com.google.common.collect.Collections2; +import org.nd4j.shade.guava.base.Predicate; +import org.nd4j.shade.guava.collect.Collections2; import org.apache.spark.Partitioner; import scala.Tuple2; @@ -25,8 +25,8 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; +import static org.nd4j.shade.guava.base.Preconditions.checkArgument; +import static org.nd4j.shade.guava.base.Preconditions.checkNotNull; /** * This is a custom partitioner that rebalances a minimum of elements @@ -97,12 +97,13 @@ public class HashingBalancedPartitioner extends Partitioner { @Override public int numPartitions() { - return Collections2.filter(partitionWeightsByClass.get(0), new Predicate() { - @Override - public boolean apply(Double aDouble) { - return aDouble >= 0; - } - }).size(); + List list = partitionWeightsByClass.get(0); + int count = 0; + for(Double d : list){ + if(d >= 0) + count++; + } + return count; } @Override diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java index b540cff82..23b702444 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.impl.common.repartition; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.PairFlatMapFunction; import scala.Tuple2; import java.util.ArrayList; @@ -30,22 +29,14 @@ import java.util.List; * * @author Alex Black */ -public class MapTupleToPairFlatMap extends BasePairFlatMapFunctionAdaptee>, T, U> { - - public MapTupleToPairFlatMap() { - super(new MapTupleToPairFlatMapAdapter()); - } -} - - -class MapTupleToPairFlatMapAdapter implements FlatMapFunctionAdapter>, Tuple2> { +public class MapTupleToPairFlatMap implements PairFlatMapFunction>, T, U> { @Override - public Iterable> call(Iterator> tuple2Iterator) throws Exception { + public Iterator> call(Iterator> tuple2Iterator) throws Exception { List> list = new ArrayList<>(); while (tuple2Iterator.hasNext()) { list.add(tuple2Iterator.next()); } - return list; + return list.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java similarity index 87% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java 
rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java index 58f092de8..99fcdc65b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java @@ -27,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * @param Type of key, associated with each example. Used to keep track of which score belongs to which example * @author Alex Black */ -public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { +public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVaeScoreWithKeyFunction { private final boolean useLogProbability; private final int numSamples; @@ -39,8 +39,8 @@ public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public BaseVaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - boolean useLogProbability, int batchSize, int numSamples) { + public BaseVaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, + boolean useLogProbability, int batchSize, int numSamples) { super(params, jsonConfig, batchSize); this.useLogProbability = useLogProbability; this.numSamples = numSamples; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java similarity index 88% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java index 19b4e1b14..da6a374c4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.common.score; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -38,8 +39,7 @@ import java.util.List; * @author Alex Black */ @Slf4j -public abstract class BaseVaeScoreWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { +public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunction>, K, Double> { protected final Broadcast params; 
protected final Broadcast jsonConfig; @@ -51,7 +51,7 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public BaseVaeScoreWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { + public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; @@ -63,9 +63,9 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } VariationalAutoencoder vae = getVaeLayer(); @@ -108,6 +108,6 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java index dd02c5d37..cdb41ba33 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java @@ -17,12 +17,10 @@ package org.deeplearning4j.spark.impl.graph.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.impl.evaluation.EvaluationRunner; import org.nd4j.evaluation.IEvaluation; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import java.util.Collections; @@ -36,26 +34,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateMDSFlatMapFunction - extends BaseFlatMapFunctionAdaptee, T[]> { - - public IEvaluateMDSFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, - T... evaluations) { - super(new IEvaluateMDSFlatMapFunctionAdapter<>(json, params, evalNumWorkers, evalBatchSize, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. - * - * @author Alex Black - */ -@Slf4j -class IEvaluateMDSFlatMapFunctionAdapter - implements FlatMapFunctionAdapter, T[]> { +public class IEvaluateMDSFlatMapFunction implements FlatMapFunction, T[]> { protected Broadcast json; protected Broadcast params; @@ -70,7 +49,7 @@ class IEvaluateMDSFlatMapFunctionAdapter * this. 
Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateMDSFlatMapFunctionAdapter(Broadcast json, Broadcast params, int evalNumWorkers, + public IEvaluateMDSFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, T[] evaluations) { this.json = json; this.params = params; @@ -80,13 +59,13 @@ class IEvaluateMDSFlatMapFunctionAdapter } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } Future f = EvaluationRunner.getInstance().execute( @@ -94,9 +73,9 @@ class IEvaluateMDSFlatMapFunctionAdapter IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList((T[])result); + return Collections.singletonList((T[])result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java index cea6a7ab0..d520605a3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.impl.graph.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.api.loader.MultiDataSetLoader; @@ -43,26 +42,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateMDSPathsFlatMapFunction - extends BaseFlatMapFunctionAdaptee, IEvaluation[]> { - - public IEvaluateMDSPathsFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, - DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, - Broadcast configuration, IEvaluation... evaluations) { - super(new IEvaluateMDSPathsFlatMapFunctionAdapter(json, params, evalNumWorkers, evalBatchSize, dsLoader, mdsLoader, configuration, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. 
- * - * @author Alex Black - */ -@Slf4j -class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter, IEvaluation[]> { +public class IEvaluateMDSPathsFlatMapFunction implements FlatMapFunction, IEvaluation[]> { protected Broadcast json; protected Broadcast params; @@ -80,7 +60,7 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< * this. Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateMDSPathsFlatMapFunctionAdapter(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, + public IEvaluateMDSPathsFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, Broadcast configuration, IEvaluation[] evaluations) { this.json = json; this.params = params; @@ -93,9 +73,9 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< } @Override - public Iterable call(Iterator paths) throws Exception { + public Iterator call(Iterator paths) throws Exception { if (!paths.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiDataSetIterator iter; @@ -109,9 +89,9 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< Future f = EvaluationRunner.getInstance().execute(evaluations, evalNumWorkers, evalBatchSize, null, iter, true, json, params); IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index 107b22284..0e5f01343 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -21,7 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter; +import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -33,7 +33,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * @author Alex Black * @see CGVaeReconstructionProbWithKeyFunction */ -public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunctionAdapter { +public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { /** diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index 4db797bb2..835bb8fa7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -21,7 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * * @author Alex Black */ -public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunctionAdapter { +public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { /** diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java index 3f84a6fb2..cec2f5b17 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java @@ -16,11 +16,13 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -40,48 +42,19 @@ import java.util.List; * @param Type of key, associated with each example. Used to keep track of which output belongs to which input example * @author Alex Black */ -public class GraphFeedForwardWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, INDArray[]> { - - public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { - super(new GraphFeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to feed-forward examples, and get the network output (for example, class probabilities). - * A key value is used to keey track of which output corresponds to which input. - * - * @param Type of key, associated with each example. 
Used to keep track of which output belongs to which input example - * @author Alex Black - */ -class GraphFeedForwardWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class); +@Slf4j +@AllArgsConstructor +public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction>, K, INDArray[]> { private final Broadcast params; private final Broadcast jsonConfig; private final int batchSize; - /** - * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json - * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) - */ - public GraphFeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - int batchSize) { - this.params = params; - this.jsonConfig = jsonConfig; - this.batchSize = batchSize; - } - @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -129,7 +102,7 @@ class GraphFeedForwardWithKeyFunctionAdapter } if (tupleCount == 0) { - return Collections.emptyList(); + return Collections.emptyIterator(); } List> output = new ArrayList<>(tupleCount); @@ -198,7 +171,7 @@ class GraphFeedForwardWithKeyFunctionAdapter Nd4j.getExecutioner().commit(); - return output; + return output.iterator(); } private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java index ee21d483e..44474248d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java @@ -16,11 +16,12 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.DoubleFlatMapFunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -41,31 +42,15 @@ import java.util.List; * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { - - public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually. Note that scoring is batched for computational efficiency.
- * This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method<br>
- * Note: This method returns a score for each example, but the association between examples and scores is lost. In - * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} - * @author Alex Black - * @see ScoreExamplesWithKeyFunction - */ -class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter, Double> { - protected static final Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); +@Slf4j +public class ScoreExamplesFunction implements DoubleFlatMapFunction> { private final Broadcast params; private final Broadcast jsonConfig; private final boolean addRegularization; private final int batchSize; - public ScoreExamplesFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -75,9 +60,9 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter call(Iterator iterator) throws Exception { + public Iterator call(Iterator iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -121,6 +106,6 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter Type of key, associated with each example. Used to keep track of which score belongs to which example * @see ScoreExamplesFunction */ -public class ScoreExamplesWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually, where each example is associated with a particular key
- * Note that scoring is batched for computational efficiency.<br>
- * This is the Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method<br>
- * Note: The MultiDataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association - * between keys and data sets to score) - * @author Alex Black - * @param Type of key, associated with each example. Used to keep track of which score belongs to which example - * @see ScoreExamplesFunction - */ -class ScoreExamplesWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); +@Slf4j +public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>, K, Double> { private final Broadcast params; private final Broadcast jsonConfig; @@ -78,7 +55,7 @@ class ScoreExamplesWithKeyFunctionAdapter * @param addRegularizationTerms if true: add regularization terms (l1/l2) if applicable; false: don't add regularization terms * @param batchSize Batch size to use when scoring examples */ - public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -88,9 +65,9 @@ class ScoreExamplesWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -140,6 +117,6 @@ class ScoreExamplesWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java index 240aeb01e..3fdc7fd1c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -37,35 +36,23 @@ import java.util.Iterator; import java.util.List; /** Function used to score a DataSet using a ComputationGraph */ -public class ScoreFlatMapFunctionCGDataSet - extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionCGDataSetAdapter(json, params, minibatchSize)); - } -} - - -/** Function used to score a DataSet using a ComputationGraph */ -class ScoreFlatMapFunctionCGDataSetAdapter - implements FlatMapFunctionAdapter, Tuple2> { - +public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, 
Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class); private String json; private Broadcast params; private int minibatchSize; - public ScoreFlatMapFunctionCGDataSetAdapter(String json, Broadcast params, int minibatchSize) { + public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { this.json = json; this.params = params; this.minibatchSize = minibatchSize; } @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -90,6 +77,6 @@ class ScoreFlatMapFunctionCGDataSetAdapter Nd4j.getExecutioner().commit(); - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java index 91072942d..bf9e3f596 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -29,7 +28,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; -import lombok.val; import java.util.ArrayList; import java.util.Collections; @@ -37,18 +35,7 @@ import java.util.Iterator; import java.util.List; /** Function used to score a MultiDataSet using a given ComputationGraph */ -public class ScoreFlatMapFunctionCGMultiDataSet - extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionCGMultiDataSetAdapter(json, params, minibatchSize)); - } -} - - -/** Function used to score a MultiDataSet using a given ComputationGraph */ -class ScoreFlatMapFunctionCGMultiDataSetAdapter - implements FlatMapFunctionAdapter, Tuple2> { +public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class); private String json; @@ -56,16 +43,16 @@ class ScoreFlatMapFunctionCGMultiDataSetAdapter private int minibatchSize; - public ScoreFlatMapFunctionCGMultiDataSetAdapter(String json, Broadcast params, int minibatchSize) { + public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { this.json = json; this.params = params; 
this.minibatchSize = minibatchSize; } @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -91,6 +78,6 @@ class ScoreFlatMapFunctionCGMultiDataSetAdapter Nd4j.getExecutioner().commit(); - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java index f95cc46f2..0a33fb995 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.impl.multilayer.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.impl.evaluation.EvaluationRunner; import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,25 +35,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateFlatMapFunction - extends BaseFlatMapFunctionAdaptee, T[]> { - - public IEvaluateFlatMapFunction(boolean isCompGraph, Broadcast json, Broadcast params, - int evalNumWorkers, int evalBatchSize, T... evaluations) { - super(new IEvaluateFlatMapFunctionAdapter<>(isCompGraph, json, params, evalNumWorkers, evalBatchSize, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. - * - * @author Alex Black - */ -@Slf4j -class IEvaluateFlatMapFunctionAdapter implements FlatMapFunctionAdapter, T[]> { +public class IEvaluateFlatMapFunction implements FlatMapFunction, T[]> { protected boolean isCompGraph; protected Broadcast json; @@ -70,7 +51,7 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF * this. 
Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateFlatMapFunctionAdapter(boolean isCompGraph, Broadcast json, Broadcast params, + public IEvaluateFlatMapFunction(boolean isCompGraph, Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, T[] evaluations) { this.isCompGraph = isCompGraph; this.json = json; @@ -81,9 +62,9 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } Future f = EvaluationRunner.getInstance().execute( @@ -91,9 +72,9 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList((T[])result); + return Collections.singletonList((T[])result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 2bdf926b6..03e4e55cf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -16,17 +16,15 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSetUtil; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.linalg.util.DataSetUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; @@ -44,23 +42,7 @@ import java.util.List; * @author Alex Black */ public class FeedForwardWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>>, K, INDArray> { - - public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { - super(new FeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to feed-forward examples, and get the network output (for example, class probabilities). - * A key value is used to keey track of which output corresponds to which input. - * - * @param Type of key, associated with each example. 
Used to keep track of which output belongs to which input example - * @author Alex Black - */ -class FeedForwardWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>>, Tuple2> { + implements PairFlatMapFunction>>, K, INDArray> { protected static Logger log = LoggerFactory.getLogger(FeedForwardWithKeyFunction.class); @@ -73,7 +55,7 @@ class FeedForwardWithKeyFunctionAdapter * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) */ - public FeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { + public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; @@ -81,9 +63,9 @@ class FeedForwardWithKeyFunctionAdapter @Override - public Iterable> call(Iterator>> iterator) throws Exception { + public Iterator> call(Iterator>> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -129,7 +111,7 @@ class FeedForwardWithKeyFunctionAdapter } if (tupleCount == 0) { - return Collections.emptyList(); + return Collections.emptyIterator(); } List> output = new ArrayList<>(tupleCount); @@ -185,7 +167,7 @@ class FeedForwardWithKeyFunctionAdapter Nd4j.getExecutioner().commit(); - return output; + return output.iterator(); } private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java index 0b3383381..4142750d0 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java @@ -16,11 +16,11 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import org.apache.spark.api.java.function.DoubleFlatMapFunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -39,23 +39,7 @@ import java.util.List; * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { - - public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually. Note that scoring is batched for computational efficiency.
- * This is essentially a Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
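For orientation, a hedged usage sketch (not part of this patch) of how the reworked ScoreExamplesFunction plugs straight into the Spark 2 Java API now that the adapter indirection is gone; the RDD, broadcasts, and batch size of 32 are illustrative:

    import org.apache.spark.api.java.JavaDoubleRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.broadcast.Broadcast;
    import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;

    class ScoreExamplesUsageSketch {
        // Score every example in the RDD; the function batches internally for efficiency.
        static JavaDoubleRDD scoreAll(JavaRDD<DataSet> data,
                                      Broadcast<INDArray> params,
                                      Broadcast<String> jsonConfig) {
            // As the javadoc notes next, the example-to-score association is lost here;
            // ScoreExamplesWithKeyFunction is the keyed alternative.
            return data.mapPartitionsToDouble(
                    new ScoreExamplesFunction(params, jsonConfig, true, 32));
        }
    }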
- * Note: This method returns a score for each example, but the association between examples and scores is lost. In - * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} - * @author Alex Black - * @see ScoreExamplesWithKeyFunction - */ -class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter, Double> { +public class ScoreExamplesFunction implements DoubleFlatMapFunction> { protected static Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); @@ -64,7 +48,7 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter params, Broadcast jsonConfig, + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -74,9 +58,9 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter call(Iterator iterator) throws Exception { + public Iterator call(Iterator iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -119,6 +103,6 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/** - * Function to score examples individually, where each example is associated with a particular key
- * Note that scoring is batched for computational efficiency.
- * This is the Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
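The change repeated across all of these scoring and evaluation classes is the Spark 2.x contract itself: partition functions return Iterator rather than Iterable, so results are materialized and handed back via .iterator() (or Collections.emptyIterator() for the empty case). A minimal self-contained sketch of that contract; the class name and element types are illustrative:

    import org.apache.spark.api.java.function.FlatMapFunction;
    import java.util.Collections;
    import java.util.Iterator;

    class PartitionCountFunction implements FlatMapFunction<Iterator<String>, Integer> {
        @Override
        public Iterator<Integer> call(Iterator<String> partition) {
            int count = 0;
            while (partition.hasNext()) {
                partition.next();
                count++;
            }
            // Spark 2.x: build the full result, then return its iterator
            // (under Spark 1.x the Iterable itself was returned).
            return Collections.singletonList(count).iterator();
        }
    }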
- * Note: The DataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association - * between keys and data sets to score) - * - * @param Type of key, associated with each example. Used to keep track of which score belongs to which example - * @author Alex Black - * @see ScoreExamplesFunction - */ -class ScoreExamplesWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); +@Slf4j +public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>, K, Double> { private final Broadcast params; private final Broadcast jsonConfig; @@ -81,8 +56,7 @@ class ScoreExamplesWithKeyFunctionAdapter * @param addRegularizationTerms if true: add regularization terms (L1, L2) to the score * @param batchSize Batch size to use when scoring */ - public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.addRegularization = addRegularizationTerms; @@ -91,9 +65,9 @@ class ScoreExamplesWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -143,6 +117,6 @@ class ScoreExamplesWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java index 480f1dcd2..8063ba8e3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java @@ -16,9 +16,11 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -26,43 +28,25 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import scala.Tuple2; -import lombok.val; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; -public class ScoreFlatMapFunction extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public 
ScoreFlatMapFunction(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionAdapter(json, params, minibatchSize)); - } - -} - - -class ScoreFlatMapFunctionAdapter implements FlatMapFunctionAdapter, Tuple2> { - - private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunction.class); +@Slf4j +@AllArgsConstructor +public class ScoreFlatMapFunction implements FlatMapFunction, Tuple2> { private String json; private Broadcast params; private int minibatchSize; - public ScoreFlatMapFunctionAdapter(String json, Broadcast params, int minibatchSize) { - this.json = json; - this.params = params; - this.minibatchSize = minibatchSize; - } - @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -87,6 +71,6 @@ class ScoreFlatMapFunctionAdapter implements FlatMapFunctionAdapter - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, - int batchSize) { - super(new VaeReconstructionErrorWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a - * MultiLayerNetwork.
- * Note that the VAE must be using a loss function, not a {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution}
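A hedged configuration sketch of the distinction drawn here (layer sizes and the MSE loss are arbitrary choices, not from this patch): reconstruction error requires an explicit loss function on the VAE layer, while reconstruction probability (further below) is the variant that works with a ReconstructionDistribution:

    import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    class VaeLossConfigSketch {
        static VariationalAutoencoder vaeForReconstructionError() {
            return new VariationalAutoencoder.Builder()
                    .nIn(784).nOut(32)
                    .encoderLayerSizes(256).decoderLayerSizes(256)
                    // Explicit loss function: what reconstruction-error scoring needs
                    .lossFunction(Activation.SIGMOID, LossFunctions.LossFunction.MSE)
                    .build();
        }
    }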
- * Also note that scoring is batched for computational efficiency.
- * - * @author Alex Black - * @see VaeReconstructionProbWithKeyFunction - */ -class VaeReconstructionErrorWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { +public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { /** * @param params MultiLayerNetwork parameters * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public VaeReconstructionErrorWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { super(params, jsonConfig, batchSize); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index 7bba68b28..e8fc8416f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -21,12 +21,8 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import scala.Tuple2; - -import java.util.Iterator; /** @@ -36,25 +32,7 @@ import java.util.Iterator; * * @author Alex Black */ -public class VaeReconstructionProbWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean useLogProbability, int batchSize, int numSamples) { - super(new VaeReconstructionProbWithKeyFunctionAdapter(params, jsonConfig, useLogProbability, batchSize, - numSamples)); - } -} - - -/** - * Function to calculate the reconstruction probability for a variational autoencoder, that is the first layer in a - * MultiLayerNetwork.
- * Note that scoring is batched for computational efficiency.
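From this point on the patch also mechanically rewrites Guava imports to ND4J's shaded copy, so downstream classpaths no longer see a direct com.google.common dependency. The pattern, shown with an arbitrary Guava class (hedged example, not taken from the patch):

    // Before: import com.google.common.collect.Lists;   (direct Guava, now avoided)
    // After: the same API relocated under the shaded namespace
    import org.nd4j.shade.guava.collect.Lists;
    import java.util.List;

    class ShadedGuavaSketch {
        static List<String> demo() {
            // Identical Guava behavior; only the package prefix changes.
            return Lists.newArrayList("a", "b");
        }
    }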
- * - * @author Alex Black - */ -class VaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeReconstructionProbWithKeyFunctionAdapter { +public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { /** @@ -64,7 +42,7 @@ class VaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeReconstructi * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public VaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, int batchSize, int numSamples) { super(params, jsonConfig, useLogProbability, batchSize, numSamples); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 30e6b395f..b6d5654ec 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -67,7 +67,7 @@ import java.io.IOException; import java.io.OutputStream; import java.util.*; -import static com.google.common.base.Preconditions.checkArgument; +import static org.nd4j.shade.guava.base.Preconditions.checkArgument; /** * ParameterAveragingTrainingMaster: A {@link TrainingMaster} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java deleted file mode 100644 index 5f9d25547..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.DoubleFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -/** - * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DoubleFlatMapFunction - * - */ -public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseDoubleFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterable call(T t) throws Exception { - return adapter.call(t); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java deleted file mode 100644 index b58cab202..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -/** - * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to PairFlatMapFunction - * - */ -public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { - - protected final FlatMapFunctionAdapter> adapter; - - public BasePairFlatMapFunctionAdaptee(FlatMapFunctionAdapter> adapter) { - this.adapter = adapter; - } - - @Override - public Iterable> call(T t) throws Exception { - return adapter.call(t); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java deleted file mode 100644 index 49a05231b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.DoubleFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Iterator; - -/** - * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DoubleFlatMapFunction - * - */ -public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseDoubleFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterator call(T t) throws Exception { - return adapter.call(t).iterator(); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java deleted file mode 100644 index b28a8dcc2..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -import java.util.Iterator; - -/** - * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to PairFlatMapFunction - * - */ -public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { - - protected final FlatMapFunctionAdapter> adapter; - - public BasePairFlatMapFunctionAdaptee(FlatMapFunctionAdapter> adapter) { - this.adapter = adapter; - } - - @Override - public Iterator> call(T t) throws Exception { - return adapter.call(t).iterator(); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml index 1919fcb06..f753fefae 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml @@ -36,16 +36,9 @@ UTF-8 UTF-8 - - 2.1.0 - 2 1.0.0_spark_2-SNAPSHOT - 2.1.0 2.11.12 @@ -194,21 +187,6 @@ jaxb-impl ${jaxb.version}
- - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - io.netty netty @@ -280,36 +258,6 @@
- - spark-2 - - - spark.major.version - 2 - - - - - com.typesafe.akka - akka-remote_2.11 - 2.3.11 - - - - - cdh5 - - - org.apache.hadoop - https://repository.cloudera.com/artifactory/cloudera-repos/ - - - - 2.0.0-cdh4.6.0 - 1.2.0-cdh5.3.0 - - - test-nd4j-native diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml index f624beb6b..1b4f33c1e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml @@ -130,26 +130,11 @@ ${project.version}
- - com.google.guava - guava - ${guava.version} - com.google.protobuf protobuf-java ${google.protobuf.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - javax.ws.rs javax.ws.rs-api @@ -226,76 +211,6 @@ leveldbjni-all ${leveldb.version} - - com.typesafe.akka - akka-contrib_2.11 - ${akka.version} - - - - - - - com.fasterxml.jackson.core - jackson-core - ${spark.jackson.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${spark.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark.jackson.version} - - - - - com.fasterxml.jackson.module - jackson-module-scala_2.11 - ${spark.jackson.version} - - - com.google.code.findbugs - jsr305 - - - - - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark.jackson.version} - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark.jackson.version} - - - - - - - com.typesafe - config - ${typesafe.config.version} - - - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - com.beust diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java index a8a499ec6..f694291c8 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java @@ -17,10 +17,17 @@ package org.deeplearning4j.ui.api; /** - * Enumeration for the type of function. Mainly used in specifying {@link Route} instances + * Enumeration for the type of function. Mainly used in specifying {@link Route} instances
+ * Supplier: no args
+ * Function: 1 arg
+ * BiFunction: 2 args
+ * Function3: 3 args
+ * Request0Function: Supplier + request, no args (as Function)
+ * Request1Function: Supplier + request + 1 arg (as BiFunction)
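The two request-aware entries correspond to the static factories added to Route below. A hedged sketch of a route built with one of them; the handler body is illustrative, not from this patch:

    import org.deeplearning4j.ui.api.HttpMethod;
    import org.deeplearning4j.ui.api.Route;
    import play.mvc.Http;
    import static play.mvc.Results.ok;

    class RequestRouteSketch {
        static Route uploadRoute() {
            // Play 2.7 removed the thread-local Http.Context, so the handler
            // now receives the Http.Request explicitly.
            return Route.request0Function("/tsne/upload", HttpMethod.POST,
                    (Http.Request request) -> ok("content type: "
                            + request.contentType().orElse("unknown")));
        }
    }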
* * @author Alex Black */ public enum FunctionType { - Supplier, Function, BiFunction, Function3 + Supplier, Function, BiFunction, Function3, + Request0Function, Request1Function } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java index 0f0a731d8..f2dd1e017 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java @@ -18,6 +18,7 @@ package org.deeplearning4j.ui.api; import lombok.AllArgsConstructor; import lombok.Data; +import play.mvc.Http; import play.mvc.Result; import java.util.function.BiFunction; @@ -38,17 +39,27 @@ public class Route { private final Supplier supplier; private final Function function; private final BiFunction function2; + private final Function request0Function; + private final BiFunction request1Function; public Route(String route, HttpMethod method, FunctionType functionType, Supplier supplier) { - this(route, method, functionType, supplier, null, null); + this(route, method, functionType, supplier, null, null, null, null); } public Route(String route, HttpMethod method, FunctionType functionType, Function function) { - this(route, method, functionType, null, function, null); + this(route, method, functionType, null, function, null, null, null); + } + + public static Route request0Function(String route, HttpMethod httpMethod, Function function){ + return new Route(route, httpMethod, FunctionType.Request0Function, null, null, null, function, null); + } + + public static Route request1Function(String route, HttpMethod httpMethod, BiFunction function){ + return new Route(route, httpMethod, FunctionType.Request1Function, null, null, null, null, function); } public Route(String route, HttpMethod method, FunctionType functionType, BiFunction function) { - this(route, method, functionType, null, null, function); + this(route, method, functionType, null, null, function, null, null); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java index fe92756cc..0886af8a9 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java @@ -24,6 +24,7 @@ import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.i18n.I18NResource; +import play.libs.Files; import play.libs.Json; import play.mvc.Http; import play.mvc.Result; @@ -31,9 +32,9 @@ import play.mvc.Results; import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.*; -import static play.mvc.Controller.request; import static play.mvc.Results.badRequest; import static play.mvc.Results.ok; @@ -63,8 +64,8 @@ public class TsneModule implements UIModule { () -> ok(org.deeplearning4j.ui.views.html.tsne.Tsne.apply())); Route r2 = new Route("/tsne/sessions", HttpMethod.GET, FunctionType.Supplier, this::listSessions); 
Route r3 = new Route("/tsne/coords/:sid", HttpMethod.GET, FunctionType.Function, this::getCoords); - Route r4 = new Route("/tsne/upload", HttpMethod.POST, FunctionType.Supplier, this::uploadFile); - Route r5 = new Route("/tsne/post/:sid", HttpMethod.POST, FunctionType.Function, this::postFile); + Route r4 = Route.request0Function("/tsne/upload", HttpMethod.POST, this::uploadFile); + Route r5 = Route.request1Function("/tsne/post/:sid", HttpMethod.POST, this::postFile); return Arrays.asList(r1, r2, r3, r4, r5); } @@ -106,22 +107,22 @@ public class TsneModule implements UIModule { } } - private Result uploadFile() { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List fileParts = body.getFiles(); + private Result uploadFile(Http.Request request) { + Http.MultipartFormData body = request.body().asMultipartFormData(); + List> fileParts = body.getFiles(); if (fileParts.isEmpty()) { return badRequest("No file uploaded"); } - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); + Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); String fileName = uploadedFile.getFilename(); String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getFile(); + File file = uploadedFile.getRef().path().toFile(); try { - uploadedFileLines = FileUtils.readLines(file); + uploadedFileLines = FileUtils.readLines(file, StandardCharsets.UTF_8); } catch (IOException e) { return badRequest("Could not read from uploaded file"); } @@ -129,21 +130,21 @@ public class TsneModule implements UIModule { return ok("File uploaded: " + fileName + ", " + contentType + ", " + file); } - private Result postFile(String sid) { + private Result postFile(Http.Request request, String sid) { // System.out.println("POST FILE CALLED: " + sid); - Http.MultipartFormData body = request().body().asMultipartFormData(); - List fileParts = body.getFiles(); + Http.MultipartFormData body = request.body().asMultipartFormData(); + List> fileParts = body.getFiles(); if (fileParts.isEmpty()) { // System.out.println("**** NO FILE ****"); return badRequest("No file uploaded"); } - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); + Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); String fileName = uploadedFile.getFilename(); String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getFile(); + File file = uploadedFile.getRef().path().toFile(); List lines; try { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java index 747284e7e..c0ebe4ed8 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java @@ -37,18 +37,16 @@ import org.deeplearning4j.ui.module.defaultModule.DefaultModule; import org.deeplearning4j.ui.module.remote.RemoteReceiverModule; import org.deeplearning4j.ui.module.train.TrainModule; import org.deeplearning4j.ui.module.tsne.TsneModule; -import org.deeplearning4j.ui.play.misc.FunctionUtil; import org.deeplearning4j.ui.play.staticroutes.Assets; -import org.deeplearning4j.ui.play.staticroutes.I18NRoute; -import org.deeplearning4j.ui.play.staticroutes.MultiSessionI18NRoute; import org.deeplearning4j.ui.storage.FileStatsStorage; import 
org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener; import org.deeplearning4j.util.DL4JFileUtils; import org.nd4j.linalg.function.Function; import org.nd4j.linalg.primitives.Pair; +import play.BuiltInComponents; import play.Mode; -import play.api.routing.Router; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -60,6 +58,8 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; +import static play.mvc.Results.ok; + /** * A UI server based on the Play framework @@ -166,63 +166,6 @@ public class PlayUIServer extends UIServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - - //Set up index page and assets routing - //The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent - // definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions - //This way, routing is not directly dependent ot Play API. Furthermore, Play 2.5 switches to using these Java interfaces - // anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls... - if (multiSession) { - routingDsl.GET("/setlang/:sessionId/:to").routeTo(FunctionUtil.biFunction(new MultiSessionI18NRoute())); - } else { - routingDsl.GET("/setlang/:to").routeTo(FunctionUtil.function(new I18NRoute())); - } - routingDsl.GET("/assets/*file").routeTo(FunctionUtil.function(new Assets(ASSETS_ROOT_DIRECTORY))); - - uiModules.add(new DefaultModule(multiSession)); //For: navigation page "/" - uiModules.add(new TrainModule(multiSession, statsStorageLoader, this::getAddress)); - uiModules.add(new ConvolutionalListenerModule()); - uiModules.add(new TsneModule()); - uiModules.add(new SameDiffModule()); - remoteReceiverModule = new RemoteReceiverModule(); - uiModules.add(remoteReceiverModule); - - //Check service loader mechanism (Arbiter UI, etc) for modules - modulesViaServiceLoader(uiModules); - - for (UIModule m : uiModules) { - List routes = m.getRoutes(); - for (Route r : routes) { - RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute()); - switch (r.getFunctionType()) { - case Supplier: - ppm.routeTo(FunctionUtil.function0(r.getSupplier())); - break; - case Function: - ppm.routeTo(FunctionUtil.function(r.getFunction())); - break; - case BiFunction: - ppm.routeTo(FunctionUtil.biFunction(r.getFunction2())); - break; - case Function3: - default: - throw new RuntimeException("Not yet implemented"); - } - } - - //Determine which type IDs this module wants to receive: - List typeIDs = m.getCallbackTypeIDs(); - for (String typeID : typeIDs) { - List list = typeIDModuleMap.get(typeID); - if (list == null) { - list = Collections.synchronizedList(new ArrayList<>()); - typeIDModuleMap.put(typeID, list); - } - list.add(m); - } - } - String portProperty = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY); if (portProperty != null) { try { @@ -233,6 +176,7 @@ public class PlayUIServer extends UIServer { } } + //Set play secret key, if required //http://www.playframework.com/documentation/latest/ApplicationSecret String crypto = System.getProperty("play.crypto.secret"); @@ -245,9 +189,9 @@ public class PlayUIServer extends UIServer { System.setProperty("play.crypto.secret", base64); } - Router router = routingDsl.build(); + try { - server = Server.forRouter(router, Mode.PROD, port); + server = 
Server.forRouter(Mode.PROD, port, this::createRouter); } catch (Throwable e){ if(e.getMessage().contains("'play.crypto.provider")){ //Usual cause: user's uber-jar does not include application.conf @@ -284,6 +228,79 @@ public class PlayUIServer extends UIServer { setStopped(false); } + protected Router createRouter(BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); + + //Set up index page and assets routing + //The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent + // definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions + //This way, routing is not directly dependent ot Play API. Furthermore, Play 2.5 switches to using these Java interfaces + // anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls... + if (multiSession) { + routingDsl.GET("/setlang/:sessionId/:to").routingTo((request, sid, to) -> { + I18NProvider.getInstance(sid.toString()).setDefaultLanguage(to.toString()); + return ok(); + }); + } else { + routingDsl.GET("/setlang/:to").routingTo((request, to) -> { + I18NProvider.getInstance().setDefaultLanguage(to.toString()); + return ok(); + }); + } + routingDsl.GET("/assets/*file").routingTo((request, file) -> Assets.assetRequest(ASSETS_ROOT_DIRECTORY, file.toString())); + + uiModules.add(new DefaultModule(multiSession)); //For: navigation page "/" + uiModules.add(new TrainModule(multiSession, statsStorageLoader, this::getAddress)); + uiModules.add(new ConvolutionalListenerModule()); + uiModules.add(new TsneModule()); + uiModules.add(new SameDiffModule()); + remoteReceiverModule = new RemoteReceiverModule(); + uiModules.add(remoteReceiverModule); + + //Check service loader mechanism (Arbiter UI, etc) for modules + modulesViaServiceLoader(uiModules); + + for (UIModule m : uiModules) { + List routes = m.getRoutes(); + for (Route r : routes) { + RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute()); + switch (r.getFunctionType()) { + case Supplier: + ppm.routingTo(request -> r.getSupplier().get()); + break; + case Function: + ppm.routingTo((request, arg) -> r.getFunction().apply(arg.toString())); + break; + case BiFunction: + ppm.routingTo((request, arg0, arg1) -> r.getFunction2().apply(arg0.toString(), arg1.toString())); + break; + case Request0Function: + ppm.routingTo(request -> r.getRequest0Function().apply(request)); + break; + case Request1Function: + ppm.routingTo((request, arg0) -> r.getRequest1Function().apply(request, arg0.toString())); + break; + case Function3: + default: + throw new RuntimeException("Not yet implemented"); + } + } + + //Determine which type IDs this module wants to receive: + List typeIDs = m.getCallbackTypeIDs(); + for (String typeID : typeIDs) { + List list = typeIDModuleMap.get(typeID); + if (list == null) { + list = Collections.synchronizedList(new ArrayList<>()); + typeIDModuleMap.put(typeID, list); + } + list.add(m); + } + } + Router router = routingDsl.build(); + return router; + } + @Override public String getAddress() { String addr = server.mainAddress().toString(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java deleted file mode 100644 index d1b87c443..000000000 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.ui.play.misc; - -import play.libs.F; -import play.mvc.Result; - -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; - } - - public static F.Function function(Function function) { - return function::apply; - } - - public static F.Function2 biFunction(BiFunction function) { - return function::apply; - } - -} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java index dd100b3d1..76d6e7361 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java @@ -16,18 +16,17 @@ package org.deeplearning4j.ui.play.staticroutes; -import com.google.common.net.HttpHeaders; +import org.nd4j.shade.guava.net.HttpHeaders; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FilenameUtils; import org.nd4j.linalg.io.ClassPathResource; -import play.api.libs.MimeTypes; import play.mvc.Result; +import play.mvc.StaticFileMimeTypes; import java.io.InputStream; -import java.util.function.Function; +import java.util.Optional; -import static play.mvc.Http.Context.Implicit.response; import static play.mvc.Results.ok; /** @@ -37,11 +36,9 @@ import static play.mvc.Results.ok; */ @AllArgsConstructor @Slf4j -public class Assets implements Function { - private final String assetsRootDirectory; +public class Assets { - @Override - public Result apply(String s) { + public static Result assetRequest(String assetsRootDirectory, String s) { String fullPath; if(s.startsWith("webjars/")){ @@ -60,15 +57,12 @@ public class Assets implements Function { String fileName = FilenameUtils.getName(fullPath); - response().setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\""); - scala.Option contentType = MimeTypes.forFileName(fileName); + Optional contentType = StaticFileMimeTypes.fileMimeTypes().forFileName(fileName); String ct; - if (contentType.isDefined()) { - ct = contentType.get(); - } else { - ct = "application/octet-stream"; - } + ct = 
contentType.orElse("application/octet-stream"); - return ok(inputStream).as(ct); + return ok(inputStream) + .withHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"") + .as(ct); } } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index be75282e5..e2890fc61 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -16,7 +16,7 @@ package org.deeplearning4j.integration; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 9440a59ae..b1de34a2c 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -17,8 +17,8 @@ package org.deeplearning4j.integration; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import org.deeplearning4j.integration.util.CountingMultiDataSetIterator; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java index 463734b82..a70d8dd2f 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java @@ -16,7 +16,7 @@ package org.deeplearning4j.integration.testcases; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.testcases.misc.CharacterIterator; import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor; diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index 954d8341d..a139c4f44 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -174,11 +174,6 @@ slf4j-api ${slf4j.version}
- - com.google.guava - guava - ${guava.version} - junit junit diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java index 1156de102..34b305001 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.listeners; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import java.util.Arrays; import java.util.HashSet; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java index 1932a6c75..7dbb0119d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java @@ -1,7 +1,7 @@ package org.nd4j.autodiff.listeners.checkpoint; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index 36334b648..b063e18a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -16,10 +16,10 @@ package org.nd4j.autodiff.listeners.records; -import com.google.common.base.Predicates; -import com.google.common.collect.Collections2; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.base.Predicates; +import org.nd4j.shade.guava.collect.Collections2; +import org.nd4j.shade.guava.collect.ImmutableMap; +import org.nd4j.shade.guava.collect.Lists; import java.util.ArrayList; import java.util.Collection; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index e6f30d12e..e09ceda75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,11 +16,11 @@ package org.nd4j.autodiff.samediff; -import com.google.common.base.Predicates; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Maps; -import com.google.common.collect.Table; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.base.Predicates; +import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Maps; +import org.nd4j.shade.guava.collect.Table; +import org.nd4j.shade.guava.primitives.Ints; import 
com.google.flatbuffers.FlatBufferBuilder; import lombok.*; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 1bde174c1..7da89aa36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.ops; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import java.util.HashMap; import java.util.Map; import java.util.Set; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index ef1bfefb7..6faf29bfc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.serde; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; import java.util.*; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 3e329ad13..5bc175952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -16,8 +16,8 @@ package org.nd4j.autodiff.validation; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -560,7 +560,7 @@ public class OpValidation { ImmutableSet info; try { //Dependency note: this ClassPath class was added in Guava 14 - info = com.google.common.reflect.ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader()) + info = org.nd4j.shade.guava.reflect.ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader()) .getTopLevelClassesRecursive("org.nd4j.linalg.api.ops"); } catch (IOException e) { //Should never happen diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java index e07b5c9b8..fd08e4270 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java @@ -42,6 +42,7 @@ import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; +import 
org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.databind.module.SimpleModule; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; @@ -80,7 +81,8 @@ public abstract class BaseEvaluation implements IEvalu .withFieldVisibility(JsonAutoDetect.Visibility.ANY) .withGetterVisibility(JsonAutoDetect.Visibility.NONE) .withSetterVisibility(JsonAutoDetect.Visibility.NONE) - .withCreatorVisibility(JsonAutoDetect.Visibility.NONE)); + .withCreatorVisibility(JsonAutoDetect.Visibility.ANY) + ); return ret; } @@ -107,15 +109,15 @@ public abstract class BaseEvaluation implements IEvalu public static T fromJson(String json, Class clazz) { try { return objectMapper.readValue(json, clazz); - } catch (IllegalArgumentException e) { - if (e.getMessage().contains("Invalid type id")) { + } catch (InvalidTypeIdException e) { + if (e.getMessage().contains("Could not resolve type id")) { try { return (T) attempFromLegacyFromJson(json, e); } catch (Throwable t) { throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", t); } } - throw e; + throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } @@ -129,7 +131,7 @@ public abstract class BaseEvaluation implements IEvalu * @param json JSON to attempt to deserialize * @param originalException Original exception to be re-thrown if it isn't legacy JSON */ - protected static T attempFromLegacyFromJson(String json, IllegalArgumentException originalException) { + protected static T attempFromLegacyFromJson(String json, InvalidTypeIdException originalException) throws InvalidTypeIdException { if (json.contains("org.deeplearning4j.eval.Evaluation")) { String newJson = json.replaceAll("org.deeplearning4j.eval.Evaluation", "org.nd4j.evaluation.classification.Evaluation"); return (T) fromJson(newJson, Evaluation.class); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java index dafe571f0..01fd322e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java @@ -16,8 +16,8 @@ package org.nd4j.evaluation.classification; -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; +import org.nd4j.shade.guava.collect.HashMultiset; +import org.nd4j.shade.guava.collect.Multiset; import lombok.Getter; import java.io.Serializable; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java index fed8c63b4..26ef8bba9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java @@ -16,7 +16,7 @@ package org.nd4j.evaluation.custom; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java
index cbad73da3..079755055 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java
@@ -16,7 +16,7 @@
 package org.nd4j.evaluation.custom;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
 import java.util.List;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java
index 6acc7fc4e..a8f5c32e9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java
@@ -16,7 +16,7 @@
 package org.nd4j.evaluation.serde;
-import com.google.common.collect.Multiset;
+import org.nd4j.shade.guava.collect.Multiset;
 import org.nd4j.evaluation.classification.ConfusionMatrix;
 import org.nd4j.shade.jackson.core.JsonGenerator;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java
index 719ac792d..7a651fb88 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java
@@ -18,9 +18,9 @@
 package org.nd4j.imports.graphmapper.onnx;
 import org.nd4j.shade.protobuf.ByteString;
 import org.nd4j.shade.protobuf.Message;
-import com.google.common.primitives.Floats;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Floats;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.val;
 import onnx.Onnx;
 import org.nd4j.autodiff.functions.DifferentialFunction;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
index f57fef4c7..3ad3267c2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java
@@ -17,8 +17,8 @@
 package org.nd4j.imports.graphmapper.tf;
 import org.nd4j.shade.protobuf.Message;
-import com.google.common.primitives.Floats;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Floats;
+import org.nd4j.shade.guava.primitives.Ints;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.nd4j.autodiff.functions.DifferentialFunction;
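The relocations above are purely mechanical: the shaded copies keep Guava's behavior, only the package differs. A small self-contained sketch (not from the patch; class name is illustrative):

    import org.nd4j.shade.guava.primitives.Ints;
    import org.nd4j.shade.guava.primitives.Longs;
    import java.util.Arrays;

    public class ShadedPrimitivesDemo {
        public static void main(String[] args) {
            // List<Integer> -> int[] and List<Long> -> long[], exactly as with stock Guava
            int[] dims = Ints.toArray(Arrays.asList(0, 2, 3));
            long[] shape = Longs.toArray(Arrays.asList(2L, 3L));
            System.out.println(Arrays.toString(dims) + " " + Arrays.toString(shape));
        }
    }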
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
index 04a61f9d0..50ab2727d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
@@ -17,8 +17,8 @@
 package org.nd4j.linalg.api.ndarray;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import com.google.flatbuffers.FlatBufferBuilder;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java
index 0d7aca8e0..952c4b777 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.api.ndarray;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import net.ericaro.neoitertools.Generator;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java
index 1e85be0cd..77463c456 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java
@@ -16,9 +16,9 @@
 package org.nd4j.linalg.api.ndarray;
-import com.google.common.primitives.Doubles;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import com.google.flatbuffers.FlatBufferBuilder;
 import net.ericaro.neoitertools.Generator;
 import org.nd4j.linalg.api.blas.params.MMulTranspose;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java
index bd0f7c905..dd17fa635 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.api.ndarray;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.*;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
index 7fc0679db..10c26d29e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.api.ops;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.slf4j.Slf4j;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
index 27e8ae281..d2190098c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java
@@ -16,9 +16,9 @@
 package org.nd4j.linalg.api.ops;
-import com.google.common.collect.Lists;
-import com.google.common.primitives.Doubles;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.collect.Lists;
+import org.nd4j.shade.guava.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.*;
 import lombok.extern.slf4j.Slf4j;
 import onnx.Onnx;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java
index 8e352be13..9ae09a57d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.api.ops.aggregates;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.slf4j.Slf4j;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java
index 331dea887..e94a7bc54 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.controlflow.compat;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
 import lombok.Getter;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java
index 3de44537a..d0c1bae38 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.api.ops.impl.reduce;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.NoArgsConstructor;
 import lombok.val;
 import onnx.Onnx;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java
index b9687e598..782c70859 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.shape;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
 import lombok.val;
 import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java
index 28151a899..0a2ab4f20 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.api.ops.impl.transforms.custom;
-import com.google.common.primitives.Doubles;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Ints;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
index bbe133dbb..ef4331b0a 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java
@@ -17,8 +17,8 @@
 package org.nd4j.linalg.api.shape;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.NonNull;
 import lombok.val;
 import org.nd4j.base.Preconditions;
@@ -31,8 +31,6 @@
 import org.nd4j.linalg.api.shape.options.ArrayType;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 import java.nio.*;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java
index 0636cc75b..8eff5cf5d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.dataset;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
+import org.nd4j.shade.guava.collect.Lists;
+import org.nd4j.shade.guava.collect.Maps;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Data;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java
index a56d39567..7b208f4f2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.dataset;
-import com.google.common.base.Function;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.base.Function;
+import org.nd4j.shade.guava.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java
index a5a46dbcd..7358da385 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.dataset.api;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.dataset.SplitTestAndTrain;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java
index 487196912..1edf0d651 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java
@@ -1281,7 +1281,7 @@
     public INDArray scalar(Number value) {
         MemoryWorkspace ws = Nd4j.getMemoryManager().getCurrentWorkspace();
-        if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */
+        if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends org.nd4j.shade.guava.util.concurrent.AtomicDouble */
             return scalar(value.doubleValue());
         else if (value instanceof Float)
             return scalar(value.floatValue());
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
index c8baedfa5..0bbf69ebe 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.factory;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.NonNull;
 import lombok.val;
 import lombok.var;
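A hedged usage sketch of the scalar(Number) dispatch documented in the BaseNDArrayFactory comment above. It assumes the Nd4j.scalar(Number) convenience overload delegates to this factory method; the class name is illustrative:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.primitives.AtomicDouble;

    public class ScalarFromAtomicDouble {
        public static void main(String[] args) {
            // AtomicDouble now extends the relocated
            // org.nd4j.shade.guava.util.concurrent.AtomicDouble, so the
            // instanceof test still routes it to the double overload.
            INDArray s = Nd4j.scalar(new AtomicDouble(2.5));
            System.out.println(s);
        }
    }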
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java
index f715f93ae..3ca99c50e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/Indices.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.indexing;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/IntervalIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/IntervalIndex.java
index 68190f4e5..b055430a7 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/IntervalIndex.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/IntervalIndex.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.Getter;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java
index da0dafb72..c21993548 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.indexing;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.nd4j.base.Preconditions;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java
index 283db3e0f..33cb996e9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.EqualsAndHashCode;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java
index 6c1b93b8f..7b56f9eed 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.Data;
 import net.ericaro.neoitertools.Generator;
 import net.ericaro.neoitertools.Itertools;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Condition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Condition.java
index 97a45f2c2..3b4245fef 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Condition.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Condition.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing.conditions;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 /**
  * Condition for boolean indexing
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Identity.java
index 83ab43320..0413e2505 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Identity.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Identity.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing.functions;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 /**
  * Created by agibsonccc on 10/8/14.
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java
index 687fa53ca..e044dcda8 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java
@@ -17,7 +17,7 @@
 package org.nd4j.linalg.indexing.functions;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 import org.nd4j.linalg.factory.Nd4j;
 /**
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java
index fc305d0b3..8a6a6436f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java
@@ -17,7 +17,7 @@
 package org.nd4j.linalg.indexing.functions;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 /**
  * Created by agibsonccc on 10/8/14.
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Zero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Zero.java
index fac1d0afc..9bcb3c463 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Zero.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/functions/Zero.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.indexing.functions;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
 /**
  * Created by agibsonccc on 10/8/14.
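The indexing functions above all implement the relocated Guava Function type; a custom implementation (hypothetical, not in the patch) plugs in the same way:

    import org.nd4j.shade.guava.base.Function;

    public class TimesTwo implements Function<Number, Number> {
        @Override
        public Number apply(Number input) {
            // Same single-abstract-method shape as Value/Zero/Identity above
            return input.doubleValue() * 2.0;
        }
    }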
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
index 36d8e05fb..abb919e8c 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
@@ -16,7 +16,7 @@
 package org.nd4j.jita.handler;
-import com.google.common.collect.Table;
+import org.nd4j.shade.guava.collect.Table;
 import org.bytedeco.javacpp.Pointer;
 import org.nd4j.jita.allocator.Allocator;
 import org.nd4j.jita.allocator.enums.AllocationStatus;
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
index f6a0eafc0..fdd40f8cb 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
@@ -16,8 +16,8 @@
 package org.nd4j.jita.handler.impl;
-import com.google.common.collect.HashBasedTable;
-import com.google.common.collect.Table;
+import org.nd4j.shade.guava.collect.HashBasedTable;
+import org.nd4j.shade.guava.collect.Table;
 import lombok.Getter;
 import lombok.NonNull;
 import lombok.val;
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java
index 6c8ea85e1..1922d9ced 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java
@@ -17,8 +17,8 @@
 package org.nd4j.linalg.jcublas.util;
-import com.google.common.collect.ArrayListMultimap;
-import com.google.common.collect.Multimap;
+import org.nd4j.shade.guava.collect.ArrayListMultimap;
+import org.nd4j.shade.guava.collect.Multimap;
 import lombok.AllArgsConstructor;
 import lombok.Data;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
index 0904bdaee..960bbd646 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java
@@ -16,7 +16,6 @@
 package org.nd4j.linalg.api.indexing;
-import org.joda.time.Interval;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
diff --git a/nd4j/nd4j-common/pom.xml b/nd4j/nd4j-common/pom.xml
index d423e8012..d75c82cf4 100644
--- a/nd4j/nd4j-common/pom.xml
+++ b/nd4j/nd4j-common/pom.xml
@@ -83,6 +83,12 @@
             <version>${project.version}</version>
         </dependency>
+
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>guava</artifactId>
+            <version>${project.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
@@ -117,11 +123,6 @@
             <version>${commons-compress.version}</version>
         </dependency>
-
-        <dependency>
-            <groupId>com.google.guava</groupId>
-            <artifactId>guava</artifactId>
-        </dependency>
         <dependency>
             <groupId>commons-codec</groupId>
             <artifactId>commons-codec</artifactId>
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java
index f5ba64f54..72409e3ca 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.collection;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
 import lombok.Getter;
 import org.nd4j.base.Preconditions;
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java
index 1f5c1f9cb..e17fec5ce 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java
@@ -24,7 +24,7 @@
 import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
 @JsonSerialize(using = JsonSerializerAtomicDouble.class)
 @JsonDeserialize(using = JsonDeserializerAtomicDouble.class)
-public class AtomicDouble extends com.google.common.util.concurrent.AtomicDouble {
+public class AtomicDouble extends org.nd4j.shade.guava.util.concurrent.AtomicDouble {
     public AtomicDouble(){
         this(0.0);
@@ -40,7 +40,7 @@ public class AtomicDouble extends com.google.common.util.concurrent.AtomicDouble
     @Override
     public boolean equals(Object o){
-        //NOTE: com.google.common.util.concurrent.AtomicDouble extends Number, hence this class extends number
+        //NOTE: org.nd4j.shade.guava.util.concurrent.AtomicDouble extends Number, hence this class extends number
         if(o instanceof Number){
             return get() == ((Number)o).doubleValue();
         }
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java
index e51e75ce4..2fe33dfba 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java
@@ -16,8 +16,8 @@
 package org.nd4j.linalg.util;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import lombok.val;
 import org.apache.commons.lang3.RandomUtils;
 import org.nd4j.base.Preconditions;
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java
index e08605881..0c4377ea9 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java
@@ -16,7 +16,7 @@
 package org.nd4j.linalg.util;
-import com.google.common.collect.Table;
+import org.nd4j.shade.guava.collect.Table;
 import java.util.Collection;
 import java.util.Map;
diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java
index 83799381b..f8fca14b5 100644
--- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java
+++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java
@@ -1,6 +1,6 @@
 package org.nd4j.resources.strumpf;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import lombok.AllArgsConstructor;
 import lombok.Data;
 import lombok.NoArgsConstructor;
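An illustrative check (not in the patch; class name is mine) of the equals() contract in the AtomicDouble hunk above: the shaded-Guava-backed class compares equal to any Number carrying the same value.

    import org.nd4j.linalg.primitives.AtomicDouble;

    public class AtomicDoubleEquality {
        public static void main(String[] args) {
            AtomicDouble a = new AtomicDouble(1.5);
            // Autoboxed Double is a Number, so equals() compares doubleValue()
            System.out.println(a.equals(1.5));
            // AtomicDouble itself extends Number via the relocated base class
            System.out.println(a.equals(new AtomicDouble(1.5)));
        }
    }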
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java
index e7c8d977b..1b420d6a4 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java
@@ -16,7 +16,7 @@
 package org.nd4j.parameterserver.distributed.transport;
-import com.google.common.math.IntMath;
+import org.nd4j.shade.guava.math.IntMath;
 import io.aeron.Aeron;
 import io.aeron.FragmentAssembler;
 import io.aeron.Publication;
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java
index a34720093..70228b987 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java
@@ -16,7 +16,7 @@
 package org.nd4j.parameterserver.distributed.v2.transport.impl;
-import com.google.common.math.IntMath;
+import org.nd4j.shade.guava.math.IntMath;
 import io.aeron.Aeron;
 import io.aeron.FragmentAssembler;
 import io.aeron.Publication;
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml
index 0a75e6171..dd50f938e 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml
@@ -47,16 +47,6 @@
             <artifactId>nd4j-parameter-server</artifactId>
             <version>${project.version}</version>
         </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-actor_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-slf4j_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
         <dependency>
             <groupId>joda-time</groupId>
            <artifactId>joda-time</artifactId>
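Before the StatusServer.java rewrite that follows, a minimal standalone sketch of the Play 2.7 Java routing API it migrates to; the route, port, and class name here are illustrative, not from the patch:

    import play.BuiltInComponents;
    import play.Mode;
    import play.routing.Router;
    import play.routing.RoutingDsl;
    import play.server.Server;
    import static play.mvc.Results.ok;

    public class TinyPlayServer {
        public static void main(String[] args) {
            // Router is now built from BuiltInComponents via a factory function
            Server server = Server.forRouter(Mode.PROD, 9000, TinyPlayServer::router);
        }

        static Router router(BuiltInComponents components) {
            RoutingDsl dsl = RoutingDsl.fromComponents(components);
            // routingTo(...) lambdas replace the old anonymous F.Function callbacks
            dsl.GET("/ping/:id").routingTo((request, id) -> ok("pong " + id));
            return dsl.build();
        }
    }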
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java
index e5152e327..e8b149f67 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java
@@ -22,17 +22,14 @@
 import org.nd4j.parameterserver.model.MasterStatus;
 import org.nd4j.parameterserver.model.ServerTypeJson;
 import org.nd4j.parameterserver.model.SlaveStatus;
 import org.nd4j.parameterserver.model.SubscriberState;
+import play.BuiltInComponents;
 import play.Mode;
-import play.libs.F;
 import play.libs.Json;
-import play.mvc.Result;
+import play.routing.Router;
 import play.routing.RoutingDsl;
 import play.server.Server;
-import java.util.List;
-
 import static play.libs.Json.toJson;
-import static play.mvc.Controller.request;
 import static play.mvc.Results.ok;
@@ -70,74 +67,35 @@ public class StatusServer {
      */
     public static Server startServer(StatusStorage statusStorage, int statusServerPort) {
         log.info("Starting server on port " + statusServerPort);
-        RoutingDsl dsl = new RoutingDsl();
-        dsl.GET("/ids/").routeTo(new F.Function0<Result>() {
-
-            @Override
-            public Result apply() throws Throwable {
-                List<Integer> ids = statusStorage.ids();
-                return ok(toJson(ids));
-            }
-        });
-
-
-        dsl.GET("/state/:id").routeTo(new F.Function<String, Result>() {
-            @Override
-            public Result apply(String id) throws Throwable {
-                return ok(toJson(statusStorage.getState(Integer.parseInt(id))));
-            }
-        });
-
-        dsl.GET("/opType/:id").routeTo(new F.Function<String, Result>() {
-            @Override
-            public Result apply(String id) throws Throwable {
-                return ok(toJson(ServerTypeJson.builder()
-                        .type(statusStorage.getState(Integer.parseInt(id)).serverType())));
-            }
-        });
-
-
-        dsl.GET("/started/:id").routeTo(new F.Function<String, Result>() {
-            @Override
-            public Result apply(String id) throws Throwable {
-                return statusStorage.getState(Integer.parseInt(id)).isMaster()
-                        ? ok(toJson(MasterStatus.builder()
-                                .master(statusStorage.getState(Integer.parseInt(id)).getServerState())
-                                //note here that a responder is id + 1
-                                .responder(statusStorage
-                                        .getState(Integer.parseInt(id) + 1).getServerState())
-                                .responderN(statusStorage
-                                        .getState(Integer.parseInt(id)).getTotalUpdates())
-                                .build()))
-                        : ok(toJson(SlaveStatus.builder()
-                                .slave(statusStorage.getState(Integer.parseInt(id)).serverType())
-                                .build()));
-            }
-        });
-
-
-
-        dsl.GET("/connectioninfo/:id").routeTo(new F.Function<String, Result>() {
-            @Override
-            public Result apply(String id) throws Throwable {
-                return ok(toJson(statusStorage.getState(Integer.parseInt(id)).getConnectionInfo()));
-            }
-        });
-
-        dsl.POST("/updatestatus/:id").routeTo(new F.Function<String, Result>() {
-            @Override
-            public Result apply(String id) throws Throwable {
-                SubscriberState subscriberState = Json.fromJson(request().body().asJson(), SubscriberState.class);
-                statusStorage.updateState(subscriberState);
-                return ok(toJson(subscriberState));
-            }
-        });
-
-        Server server = Server.forRouter(dsl.build(), Mode.PROD, statusServerPort);
-
-        return server;
-
+        return Server.forRouter(Mode.PROD, statusServerPort, builtInComponents -> createRouter(statusStorage, builtInComponents));
     }
 
+    protected static Router createRouter(StatusStorage statusStorage, BuiltInComponents builtInComponents){
+        RoutingDsl dsl = RoutingDsl.fromComponents(builtInComponents);
+        dsl.GET("/ids/").routingTo(request -> ok(toJson(statusStorage.ids())));
+        dsl.GET("/state/:id").routingTo((request, id) -> ok(toJson(statusStorage.getState(Integer.parseInt(id.toString())))));
+        dsl.GET("/opType/:id").routingTo((request, id) -> ok(toJson(ServerTypeJson.builder()
+                .type(statusStorage.getState(Integer.parseInt(id.toString())).serverType()))));
+        dsl.GET("/started/:id").routingTo((request, id) -> {
+            boolean isMaster = statusStorage.getState(Integer.parseInt(id.toString())).isMaster();
+            if(isMaster){
+                return ok(toJson(MasterStatus.builder().master(statusStorage.getState(Integer.parseInt(id.toString())).getServerState())
+                        //note here that a responder is id + 1
+                        .responder(statusStorage.getState(Integer.parseInt(id.toString()) + 1).getServerState())
+                        .responderN(statusStorage.getState(Integer.parseInt(id.toString())).getTotalUpdates())
+                        .build()));
+            } else {
+                return ok(toJson(SlaveStatus.builder().slave(statusStorage.getState(Integer.parseInt(id.toString())).serverType()).build()));
+            }
+        });
+        dsl.GET("/connectioninfo/:id").routingTo((request, id) -> ok(toJson(statusStorage.getState(Integer.parseInt(id.toString())).getConnectionInfo())));
+        dsl.POST("/updatestatus/:id").routingTo((request, id) -> {
+            SubscriberState subscriberState = Json.fromJson(request.body().asJson(), SubscriberState.class);
+            statusStorage.updateState(subscriberState);
+            return ok(toJson(subscriberState));
+        });
+
+        return dsl.build();
+    }
 }
diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java
index f69b9c24b..e5ccca909 100644
--- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java
+++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java
@@ -20,7 +20,7 @@
 import com.beust.jcommander.JCommander;
 import com.beust.jcommander.Parameter;
 import com.beust.jcommander.ParameterException;
 import com.beust.jcommander.Parameters;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 import com.mashape.unirest.http.Unirest;
diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java
index 426131d4d..8ea228586 100644
--- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java
+++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java
@@ -16,7 +16,7 @@
 package org.nd4j.aeron.ipc.chunk;
-import com.google.common.collect.Maps;
+import org.nd4j.shade.guava.collect.Maps;
 import lombok.extern.slf4j.Slf4j;
 import org.nd4j.aeron.ipc.NDArrayMessage;
diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
index 99f744587..4e4ba462e 100644
--- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml
@@ -49,36 +49,6 @@
             <artifactId>joda-time</artifactId>
             <version>${jodatime.version}</version>
         </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-core</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-databind</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-annotations</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.dataformat</groupId>
-            <artifactId>jackson-dataformat-yaml</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.dataformat</groupId>
-            <artifactId>jackson-dataformat-xml</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.datatype</groupId>
-            <artifactId>jackson-datatype-joda</artifactId>
-            <version>${spark2.jackson.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.apache.arrow</groupId>
             <artifactId>arrow-vector</artifactId>
diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml
index 1e8c0f509..82770d51a 100644
--- a/nd4j/nd4j-serde/nd4j-gson/pom.xml
+++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml
@@ -21,9 +21,6 @@
         <groupId>org.nd4j</groupId>
         <version>1.0.0-SNAPSHOT</version>
-    <properties>
-        <gson.version>2.8.0</gson.version>
-    </properties>
     <modelVersion>4.0.0</modelVersion>
     <artifactId>nd4j-gson</artifactId>
diff --git a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java
index fc04af6d2..5b087a055 100644
--- a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java
+++ b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java
@@ -16,8 +16,8 @@
 package org.nd4j.serde.gson;
-import com.google.common.primitives.Ints;
-import com.google.common.primitives.Longs;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Longs;
 import com.google.gson.JsonArray;
 import com.google.gson.JsonElement;
 import com.google.gson.JsonParser;
diff
--git a/nd4j/nd4j-shade/guava/pom.xml b/nd4j/nd4j-shade/guava/pom.xml new file mode 100644 index 000000000..73b8d5825 --- /dev/null +++ b/nd4j/nd4j-shade/guava/pom.xml @@ -0,0 +1,219 @@ + + + + nd4j-shade + org.nd4j + 1.0.0-SNAPSHOT + + 4.0.0 + + guava + + + true + + + + + com.google.guava + guava + 28.0-jre + + true + + + + + + + custom-lifecycle + + + !skip.custom.lifecycle + + + + + + org.apache.portals.jetspeed-2 + jetspeed-mvn-maven-plugin + 2.3.1 + + + compile-and-pack + compile + + mvn + + + + + + org.apache.maven.shared + maven-invoker + 2.2 + + + + + + + create-shaded-jars + @rootdir@/nd4j/nd4j-shade/guava/ + clean,compile,package + + true + + + + + create-shaded-jars + + + + + + + + + + + + + com.lewisd + lint-maven-plugin + 0.0.11 + + + pom-lint + none + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + + package + + shade + + + + + reference.conf + + + + + + + + + + + false + true + true + + + + com.google.*:* + + + + + + com.google.common + org.nd4j.shade.guava + + + com.google + org.nd4j.shade + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + true + + + + empty-javadoc-jar + package + + jar + + + javadoc + ${basedir}/javadoc + + + + empty-sources-jar + package + + jar + + + sources + ${basedir}/src + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 3.0.0 + + + unpack + package + + unpack + + + + + org.nd4j + guava + ${project.version} + jar + false + ${project.build.directory}/classes/ + **/*.class,**/*.xml + + + + + + + + + + \ No newline at end of file diff --git a/nd4j/nd4j-shade/jackson/pom.xml b/nd4j/nd4j-shade/jackson/pom.xml index 1d53e1c41..ad2be71fb 100644 --- a/nd4j/nd4j-shade/jackson/pom.xml +++ b/nd4j/nd4j-shade/jackson/pom.xml @@ -32,6 +32,79 @@ true + + + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + true + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.databind.version} + true + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${jackson.version} + true + + + com.fasterxml.jackson.dataformat + jackson-dataformat-xml + ${jackson.version} + true + + + + + jackson-module-jaxb-annotations + com.fasterxml.jackson.module + + + + + com.fasterxml.jackson.datatype + jackson-datatype-joda + ${jackson.version} + true + + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + true + + + org.yaml + snakeyaml + ${shaded.snakeyaml.version} + true + + + org.codehaus.woodstox + stax2-api + 3.1.4 + true + + + com.fasterxml.woodstox + woodstox-core + 5.1.0 + true + + + + + custom-lifecycle @@ -139,6 +212,9 @@ com.fasterxml.jackson:* com.fasterxml.jackson.*:* + com.fasterxml.woodstox:* + org.yaml*:* + org.codehaus*:* @@ -148,6 +224,20 @@ com.fasterxml.jackson org.nd4j.shade.jackson + + com.ctc.wstx + org.nd4j.shade.wstx + + + + org.yaml + org.nd4j.shade.yaml + + + + org.codehaus + org.nd4j.shade.codehaus + @@ -214,45 +304,4 @@ - - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-xml - ${jackson.version} - - - - - jackson-module-jaxb-annotations - com.fasterxml.jackson.module - - - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${jackson.version} - - - - - - diff --git a/nd4j/nd4j-shade/pom.xml b/nd4j/nd4j-shade/pom.xml index 36b58087b..927292b5f 100644 --- a/nd4j/nd4j-shade/pom.xml +++ 
b/nd4j/nd4j-shade/pom.xml @@ -30,6 +30,7 @@ jackson protobuf + guava diff --git a/nd4j/nd4j-shade/protobuf/pom.xml b/nd4j/nd4j-shade/protobuf/pom.xml index 1cbd7d5a8..910003683 100644 --- a/nd4j/nd4j-shade/protobuf/pom.xml +++ b/nd4j/nd4j-shade/protobuf/pom.xml @@ -20,11 +20,26 @@ com.google.protobuf protobuf-java 3.8.0 + + true com.google.protobuf protobuf-java-util 3.8.0 + true + + + com.google.guava + guava + + + + + com.google.guava + guava + 26.0-android + true
@@ -150,6 +165,7 @@ com.google.protobuf:* com.google.protobuf.*:* + com.google.guava:* @@ -159,6 +175,11 @@ com.google.protobuf org.nd4j.shade.protobuf + + + com.google.common + org.nd4j.shade.protobuf.common + diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 6c294d7e7..f043d7299 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -77,11 +77,6 @@ slf4j-api ${slf4j.version} - - com.google.guava - guava - ${guava.version} - junit junit diff --git a/nd4s/build.sbt b/nd4s/build.sbt index 701483d0f..1fbac5ae6 100644 --- a/nd4s/build.sbt +++ b/nd4s/build.sbt @@ -38,7 +38,7 @@ lazy val commonSettings = Seq( resolvers in ThisBuild ++= Seq(Opts.resolver.sonatypeSnapshots), nd4jVersion := sys.props.getOrElse("nd4jVersion", default = "1.0.0-SNAPSHOT"), libraryDependencies ++= Seq( - "com.nativelibs4java" %% "scalaxy-loops" % "0.3.4", +// "com.nativelibs4java" %% "scalaxy-loops" % "0.3.4", // "org.nd4j" % "nd4j-api" % nd4jVersion.value, // "org.nd4j" % "nd4j-native-platform" % nd4jVersion.value % Test, "org.scalatest" %% "scalatest" % "2.2.6" % Test, diff --git a/nd4s/pom.xml b/nd4s/pom.xml index 011bc7fbe..7fc3fb7d3 100644 --- a/nd4s/pom.xml +++ b/nd4s/pom.xml @@ -68,11 +68,6 @@ - - com.nativelibs4java - scalaxy-loops_${scala.binary.version} - 0.3.4 - org.nd4j nd4j-api @@ -84,6 +79,12 @@ ${logback.version} test + + junit + junit + ${junit.version} + test + org.scalatest scalatest_${scala.binary.version} @@ -99,7 +100,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + ${breeze.version} test @@ -187,6 +188,7 @@ -deprecation -explaintypes -nobootcp + -usejavacp diff --git a/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala b/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala index 46bba7fff..e6818fbc0 100644 --- a/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala +++ b/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala @@ -21,7 +21,6 @@ import org.nd4j.linalg.api.ops.Op import org.nd4j.linalg.factory.Nd4j import org.nd4s.ops.{ BitFilterOps, FilterOps, FunctionalOpExecutioner, MapOps } -import scalaxy.loops._ import scala.language.postfixOps import scala.util.control.Breaks._ @@ -65,7 +64,7 @@ trait CollectionLikeNDArray[A <: INDArray] { val lv = ev.linearView(underlying) breakable { for { - i <- 0 until lv.length().toInt optimized + i <- 0 until lv.length().toInt } if (!f(ev.get(lv, i))) { result = true break() @@ -81,7 +80,7 @@ trait CollectionLikeNDArray[A <: INDArray] { val lv = ev.linearView(underlying) breakable { for { - i <- 0 until lv.length().toInt optimized + i <- 0 until lv.length().toInt } if (!f(ev.get(lv, i))) { result = false break() diff --git a/pom.xml b/pom.xml index ee24279d4..63054c104 100644 --- a/pom.xml +++ b/pom.xml @@ -261,15 +261,13 @@ 3.4.2 0.8.2.2 - 2.3.16 - 1.3.0 + 1.3.0 0.10.4 1.27 0.8.0 2.2 1.15 3.17 - 2.4.8 0.5.0 2.3.23 2.8.1 @@ -280,6 +278,7 @@ 6.5.7 1.4.9 0.9.10 + 1.0 false false @@ -308,7 +307,7 @@ 1.14.0 ${tensorflow.version}-${javacpp-presets.version} - 1.16.1 + 1.18 3.5 3.6 2.5 @@ -325,13 +324,15 @@ 3.2.2 4.1 + 2.4.3 + 2 2.0.29 1.7.21 4.12 1.2.3 - 2.5.1 - ${jackson.version} - 2.6.5 + 2.9.9 + 2.9.9.3 + 1.23 2.8.7 1.18.2 2.0.0 @@ -339,18 +340,16 @@ 20131018 2.6.1 false - 2.2.0 - - 2.1.0 + 2.2.0 2.16.3 3.4.6 0.5.4 3.0.5 3.15.1 - 2.4.8 - + 2.7.3 2.0 - 20.0 + 28.0-jre + 2.8.0 1.2.0-3f79e055 4.10.0 @@ -386,12 +385,12 @@ 2.2.6 - - 2.10.7 - 2.10 2.11.12 2.11 + + 2.12.9 + 2.12 3.0.5 1.3.0 diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index 497128f30..67b050ba1 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -101,6 
+101,12 @@
             <artifactId>jackson-databind</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
diff --git a/scalnet/pom.xml b/scalnet/pom.xml
index 5d0f85dd8..a6e220280 100644
--- a/scalnet/pom.xml
+++ b/scalnet/pom.xml
@@ -88,6 +88,12 @@
             <version>${scalacheck.version}</version>
             <scope>test</scope>
         </dependency>
+
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
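Finally, a hedged smoke test (mine, not from the patch; class and method names are illustrative) for the shading goal of this PR: relocated classes must resolve, while the pre-shaded coordinates should no longer resolve from ND4J's own modules once the direct guava dependencies are removed.

    import org.junit.Test;
    import static org.junit.Assert.fail;

    public class ShadingSmokeTest {
        @Test
        public void relocatedGuavaIsOnClasspath() throws Exception {
            // The relocated copy produced by the shade module must resolve
            Class.forName("org.nd4j.shade.guava.collect.ImmutableSet");
        }

        @Test
        public void unshadedGuavaIsAbsent() {
            // With direct guava dependencies dropped, the original coordinates
            // should not resolve from this module's classpath
            try {
                Class.forName("com.google.common.collect.ImmutableSet");
                fail("unshaded Guava leaked onto the classpath");
            } catch (ClassNotFoundException expected) {
                // expected: only the relocated package is available
            }
        }
    }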