diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml
index 296af48fd..064dd3ecd 100644
--- a/arbiter/arbiter-core/pom.xml
+++ b/arbiter/arbiter-core/pom.xml
@@ -72,6 +72,11 @@
test
+
+ joda-time
+ joda-time
+ ${jodatime.version}
+
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
index 9e50065e6..a9c0933c4 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.optimize.distribution;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
index fa503ef6d..0e04a130b 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.optimize.runner;
-import com.google.common.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
index a3992b09a..6982090f1 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java
@@ -16,9 +16,9 @@
package org.deeplearning4j.arbiter.optimize.runner;
-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
+import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
index 9e30a06f6..8cfb07723 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java
@@ -43,13 +43,15 @@ public class JsonMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory());
- mapper.registerModule(new JodaModule());
- mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
- mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
- mapper.enable(SerializationFeature.INDENT_OUTPUT);
- mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
- mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ yamlMapper.registerModule(new JodaModule());
+ yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
+ yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+ yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}
private JsonMapper() {}
diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
index f10c593e0..5b35220e9 100644
--- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
+++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java
@@ -39,6 +39,7 @@ public class YamlMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
index 3572b187b..6f1b336bb 100644
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
+++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
@@ -59,6 +59,7 @@ public class TestJson {
om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
return om;
}
diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml
index 77f7e34a9..85afe7a6b 100644
--- a/arbiter/arbiter-deeplearning4j/pom.xml
+++ b/arbiter/arbiter-deeplearning4j/pom.xml
@@ -57,6 +57,12 @@
jackson
${nd4j.version}
+
+
+ com.google.code.gson
+ gson
+ ${gson.version}
+
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
index 77c31707a..0a5e33d27 100644
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
+++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.layers;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;
diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml
index 93e955219..56d1013bf 100644
--- a/arbiter/arbiter-ui/pom.xml
+++ b/arbiter/arbiter-ui/pom.xml
@@ -107,12 +107,6 @@
-
- com.google.guava
- guava
- ${guava.version}
-
-
org.deeplearning4j
arbiter-core
diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
index 491756ae7..0ac5ad383 100644
--- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
+++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.ui.misc;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
+import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
@@ -45,12 +46,9 @@ public class JsonMapper {
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
mapper.enable(SerializationFeature.INDENT_OUTPUT);
-
- mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker()
- .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
- .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
- .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
- .withCreatorVisibility(JsonAutoDetect.Visibility.NONE));
+ mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+ mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
return mapper;
}
diff --git a/change-scala-versions.sh b/change-scala-versions.sh
index 5fb40c5cb..8968abbf3 100755
--- a/change-scala-versions.sh
+++ b/change-scala-versions.sh
@@ -20,9 +20,9 @@
set -e
-VALID_VERSIONS=( 2.10 2.11 )
-SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}";
+VALID_VERSIONS=( 2.11 2.12 )
SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}";
+SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}";
usage() {
echo "Usage: $(basename $0) [-h|--help]
@@ -45,19 +45,18 @@ check_scala_version() {
exit 1
}
-
check_scala_version "$TO_VERSION"
if [ $TO_VERSION = "2.11" ]; then
- FROM_BINARY="_2\.10"
+ FROM_BINARY="_2\.12"
TO_BINARY="_2\.11"
- FROM_VERSION=$SCALA_210_VERSION
+ FROM_VERSION=$SCALA_212_VERSION
TO_VERSION=$SCALA_211_VERSION
else
FROM_BINARY="_2\.11"
- TO_BINARY="_2\.10"
+ TO_BINARY="_2\.12"
FROM_VERSION=$SCALA_211_VERSION
- TO_VERSION=$SCALA_210_VERSION
+ TO_VERSION=$SCALA_212_VERSION
fi
sed_i() {
@@ -70,35 +69,24 @@ echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION t
BASEDIR=$(dirname $0)
-#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc.
+#Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc.
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \;
-#Scala versions, like 2.10
+#Scala versions, like 2.11
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \;
-#Scala binary versions, like 2.10
+#Scala binary versions, like 2.11
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \;
-#Scala versions, like scala-library 2.10.6
+#Scala versions, like scala-library 2.11.12
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \;
-#Scala maven plugin, 2.10
+#Scala maven plugin, 2.11
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \;
-
-#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306
-#https://github.com/deeplearning4j/deeplearning4j/issues/6306
-if [[ $TO_VERSION == 2.11* ]]; then
- sed_i 's/korean-text-scala-2.10<\/artifactId>/korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
- sed_i 's/4.2.0<\/version>/4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-else
- sed_i 's/korean-text<\/artifactId>/korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
- sed_i 's/4.4<\/version>/4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-fi
-
echo "Done updating Scala versions.";
diff --git a/change-spark-versions.sh b/change-spark-versions.sh
deleted file mode 100755
index 06a9b4d55..000000000
--- a/change-spark-versions.sh
+++ /dev/null
@@ -1,85 +0,0 @@
-#!/usr/bin/env bash
-
-################################################################################
-# Copyright (c) 2015-2018 Skymind, Inc.
-#
-# This program and the accompanying materials are made available under the
-# terms of the Apache License, Version 2.0 which is available at
-# https://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications.
-
-set -e
-
-VALID_VERSIONS=( 1 2 )
-SPARK_2_VERSION="2\.1\.0"
-SPARK_1_VERSION="1\.6\.3"
-
-usage() {
- echo "Usage: $(basename $0) [-h|--help]
-where :
- -h| --help Display this help text
- valid spark version values : ${VALID_VERSIONS[*]}
-" 1>&2
- exit 1
-}
-
-if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
- usage
-fi
-
-TO_VERSION=$1
-
-check_spark_version() {
- for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
- echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
- exit 1
-}
-
-
-check_spark_version "$TO_VERSION"
-
-if [ $TO_VERSION = "2" ]; then
- FROM_BINARY="1"
- TO_BINARY="2"
- FROM_VERSION=$SPARK_1_VERSION
- TO_VERSION=$SPARK_2_VERSION
-else
- FROM_BINARY="2"
- TO_BINARY="1"
- FROM_VERSION=$SPARK_2_VERSION
- TO_VERSION=$SPARK_1_VERSION
-fi
-
-sed_i() {
- sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
-}
-
-export -f sed_i
-
-echo "Updating Spark versions in pom.xml files to Spark $1";
-
-BASEDIR=$(dirname $0)
-
-# 1
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
- -exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \;
-
-# 1.6.3
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
- -exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \;
-
-#Spark versions, like xxx_spark_2xxx OR xxx_spark_2xxx
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
- -exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \;
-
-echo "Done updating Spark versions.";
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
index 44098e6cc..32f8f7cc5 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java
@@ -16,7 +16,6 @@
package org.datavec.api.transform;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -27,8 +26,7 @@ import java.util.List;
/**A Transform converts an example to another example, or a sequence to another sequence
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.TransformHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Transform extends Serializable, ColumnOp {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
index 24a029179..7c57f4daa 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java
@@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.core.JsonProcessingException;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import java.io.IOException;
import java.io.Serializable;
@@ -417,6 +418,16 @@ public class TransformProcess implements Serializable {
public static TransformProcess fromJson(String json) {
try {
return JsonMappers.getMapper().readValue(json, TransformProcess.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ try{
+ return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (IOException e) {
//TODO proper exception message
throw new RuntimeException(e);
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
index 6b069a9ec..467db70f0 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java
@@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import java.io.IOException;
@@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable {
public static DataAnalysis fromJson(String json) {
try{
return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ try{
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (Exception e){
//Legacy format
ObjectMapper om = new JsonSerializer().getObjectMapper();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
index ecc333d2d..6156ead40 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java
@@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import java.io.IOException;
import java.util.List;
@@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis {
public static SequenceDataAnalysis fromJson(String json){
try{
return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ try{
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (IOException e){
throw new RuntimeException(e);
}
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
index dd43315d2..c86584ede 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.analysis.columns;
import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -27,8 +26,7 @@ import java.io.Serializable;
* Interface for column analysis
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ColumnAnalysis extends Serializable {
long getCountTotal();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
index 83a96bfcf..6bd5b98ac 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java
@@ -18,7 +18,6 @@ package org.datavec.api.transform.condition;
import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -35,8 +34,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.ConditionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Condition extends Serializable, ColumnOp {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
index 5aa672f76..16870e9f9 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java
@@ -18,7 +18,6 @@ package org.datavec.api.transform.filter;
import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -33,8 +32,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.FilterHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Filter extends Serializable, ColumnOp {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
index cde86cf90..6889fbf31 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.metadata;
import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.io.Serializable;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ColumnMetaData extends Serializable, Cloneable {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
index 3341e0b6d..8ef67bacb 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java
@@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable;
import java.util.List;
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
+import static org.nd4j.shade.guava.base.Preconditions.checkArgument;
+import static org.nd4j.shade.guava.base.Preconditions.checkNotNull;
/**
* A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition},
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
index 5bcf81f7a..1e5177c68 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java
@@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.LongMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.comparator.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;
@@ -50,8 +49,7 @@ import java.util.List;
@EqualsAndHashCode(exclude = {"inputSchema"})
@JsonIgnoreProperties({"inputSchema"})
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class CalculateSortedRank implements Serializable, ColumnOp {
private final String newColumnName;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
index a9e167943..1c16ebcce 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java
@@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.serde.JsonMappers;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.*;
import org.joda.time.DateTimeZone;
import org.nd4j.shade.jackson.annotation.*;
@@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
+import java.io.IOException;
import java.io.Serializable;
import java.util.*;
@@ -48,8 +49,7 @@ import java.util.*;
*/
@JsonIgnoreProperties({"columnNames", "columnNamesIndex"})
@EqualsAndHashCode
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.SchemaHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
public class Schema implements Serializable {
@@ -358,6 +358,16 @@ public class Schema implements Serializable {
public static Schema fromJson(String json) {
try{
return JsonMappers.getMapper().readValue(json, Schema.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ try{
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ return JsonMappers.getLegacyMapper().readValue(json, Schema.class);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (Exception e){
//TODO better exceptions
throw new RuntimeException(e);
@@ -379,21 +389,6 @@ public class Schema implements Serializable {
}
}
- private static Schema fromJacksonString(String str, JsonFactory factory) {
- ObjectMapper om = new ObjectMapper(factory);
- om.registerModule(new JodaModule());
- om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
- om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
- om.enable(SerializationFeature.INDENT_OUTPUT);
- om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
- om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
- try {
- return om.readValue(str, Schema.class);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
public static class Builder {
List columnMetaData = new ArrayList<>();
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
index e4a09f6e9..a5677d1f5 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -30,8 +29,7 @@ import java.util.List;
* Compare the time steps of a sequence
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface SequenceComparator extends Comparator>, Serializable {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
index bef0b4ccc..3471dbaa3 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface SequenceSplit extends Serializable {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
index e25f85253..1456af8d7 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence.window;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -36,8 +35,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface WindowFunction extends Serializable {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
index 652556ce4..bfa114697 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java
@@ -16,44 +16,17 @@
package org.datavec.api.transform.serde;
-import lombok.AllArgsConstructor;
-import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
-import org.datavec.api.io.WritableComparator;
-import org.datavec.api.transform.Transform;
-import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
-import org.datavec.api.transform.condition.column.ColumnCondition;
-import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.transform.metadata.ColumnMetaData;
-import org.datavec.api.transform.rank.CalculateSortedRank;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.sequence.SequenceComparator;
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.transform.sequence.window.WindowFunction;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
-import org.datavec.api.writable.Writable;
-import org.nd4j.linalg.activations.IActivation;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
-import org.nd4j.linalg.primitives.Pair;
-import org.nd4j.serde.json.LegacyIActivationDeserializer;
-import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
+import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
-import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
-import org.nd4j.shade.jackson.databind.*;
-import org.nd4j.shade.jackson.databind.cfg.MapperConfig;
-import org.nd4j.shade.jackson.databind.introspect.Annotated;
-import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
-import org.nd4j.shade.jackson.databind.introspect.AnnotationMap;
-import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector;
-import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder;
+import org.nd4j.shade.jackson.databind.DeserializationFeature;
+import org.nd4j.shade.jackson.databind.MapperFeature;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
-import java.lang.annotation.Annotation;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-
/**
* JSON mappers for deserializing neural net configurations, etc.
*
@@ -62,38 +35,9 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class JsonMappers {
- /**
- * This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])}
- * Classes can be specified in comma-separated format
- */
- public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses";
-
- static {
- String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY);
- if(p != null && !p.isEmpty()){
- String[] split = p.split(",");
- List> list = new ArrayList<>();
- for(String s : split){
- try{
- Class> c = Class.forName(s);
- list.add(c);
- } catch (Throwable t){
- log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t);
- }
- }
-
- if(list.size() > 0){
- try {
- registerLegacyCustomClassesForJSONList(list);
- } catch (Throwable t){
- log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t);
- }
- }
- }
- }
-
private static ObjectMapper jsonMapper;
private static ObjectMapper yamlMapper;
+ private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc
static {
jsonMapper = new ObjectMapper();
@@ -102,117 +46,12 @@ public class JsonMappers {
configureMapper(yamlMapper);
}
- private static Map legacyMappers = new ConcurrentHashMap<>();
-
-
- /**
- * Register a set of classes (Transform, Filter, etc) for JSON deserialization.
- *
- * This is required ONLY when BOTH of the following conditions are met:
- * 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND
- * 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)
- *
- * By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON
- * format change between versions.
- *
- * @param classes Classes to register
- */
- public static void registerLegacyCustomClassesForJSON(Class>... classes) {
- registerLegacyCustomClassesForJSONList(Arrays.>asList(classes));
- }
-
- /**
- * @see #registerLegacyCustomClassesForJSON(Class[])
- */
- public static void registerLegacyCustomClassesForJSONList(List> classes){
- //Default names (i.e., old format for custom JSON format)
- List> list = new ArrayList<>();
- for(Class> c : classes){
- list.add(new Pair(c.getSimpleName(), c));
+ public static synchronized ObjectMapper getLegacyMapper(){
+ if(legacyMapper == null){
+ legacyMapper = LegacyJsonFormat.legacyMapper();
+ configureMapper(legacyMapper);
}
- registerLegacyCustomClassesForJSON(list);
- }
-
- /**
- * Set of classes that can be registered for legacy deserialization.
- */
- private static List> REGISTERABLE_CUSTOM_CLASSES = (List>) Arrays.>asList(
- Transform.class,
- ColumnAnalysis.class,
- ColumnCondition.class,
- Filter.class,
- ColumnMetaData.class,
- CalculateSortedRank.class,
- Schema.class,
- SequenceComparator.class,
- SequenceSplit.class,
- WindowFunction.class,
- Writable.class,
- WritableComparator.class
- );
-
- /**
- * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
- * ONLY) for JSON deserialization, with custom names.
- * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
- * but is added in case it is required in non-standard circumstances.
- */
- public static void registerLegacyCustomClassesForJSON(List> classes){
- for(Pair p : classes){
- String s = p.getFirst();
- Class c = p.getRight();
- //Check if it's a valid class to register...
- boolean found = false;
- for( Class> c2 : REGISTERABLE_CUSTOM_CLASSES){
- if(c2.isAssignableFrom(c)){
- Map map = LegacyMappingHelper.legacyMappingForClass(c2);
- map.put(p.getFirst(), p.getSecond().getName());
- found = true;
- }
- }
-
- if(!found){
- throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
- c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
- }
- }
- }
-
-
- /**
- * Get the legacy JSON mapper for the specified class.
- *
- * NOTE: This is intended for internal backward-compatibility use.
- *
- * Note to developers: The following JSON mappers are for handling legacy format JSON.
- * Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc were changed from
- * a wrapper object, to an "@class" field. However, to not break all saved transforms networks, these mappers are
- * part of the solution.
- *
- * How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)
- * 1. Transforms etc JSON that has a "@class" field are deserialized as normal
- * 2. Transforms JSON that don't have such a field are mapped (via Layer @JsonTypeInfo) to LegacyMappingHelper.TransformHelper
- * 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it
- * 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names
- * 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization
- *
- * Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format,
- * as it'll fail due to not having the expected "@class" annotation.
- * Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified
- * class anyway. The ignoring is done via an annotation introspector, defined below in this class.
- * However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of
- * all types - if we did, then any nested types would fail (i.e., an Condition in a Transform - the Transform couldn't
- * be deserialized correctly, as the annotation would be ignored).
- *
- */
- public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class> clazz){
- if(!legacyMappers.containsKey(clazz)){
- ObjectMapper m = new ObjectMapper();
- configureMapper(m);
- m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(clazz)));
- legacyMappers.put(clazz, m);
- }
- return legacyMappers.get(clazz);
+ return legacyMapper;
}
/**
@@ -237,61 +76,7 @@ public class JsonMappers {
ret.enable(SerializationFeature.INDENT_OUTPUT);
ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+ ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); //Need this otherwise JsonProperty annotations on constructors won't be seen
}
-
- /**
- * Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc.
- * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring
- * a set of JsonTypeInfo annotations
- */
- @AllArgsConstructor
- private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector {
-
- private List classList;
-
- @Override
- protected TypeResolverBuilder> _findTypeResolver(MapperConfig> config, Annotated ann, JavaType baseType) {
- if(ann instanceof AnnotatedClass){
- AnnotatedClass c = (AnnotatedClass)ann;
- Class> annClass = c.getAnnotated();
-
- boolean isAssignable = false;
- for(Class c2 : classList){
- if(c2.isAssignableFrom(annClass)){
- isAssignable = true;
- break;
- }
- }
-
- if( isAssignable ){
- AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations();
- if(annotations == null || annotations.annotations() == null){
- //Probably not necessary - but here for safety
- return super._findTypeResolver(config, ann, baseType);
- }
-
- AnnotationMap newMap = null;
- for(Annotation a : annotations.annotations()){
- Class> annType = a.annotationType();
- if(annType == JsonTypeInfo.class){
- //Ignore the JsonTypeInfo annotation on the Layer class
- continue;
- }
- if(newMap == null){
- newMap = new AnnotationMap();
- }
- newMap.add(a);
- }
- if(newMap == null)
- return null;
-
- //Pass the remaining annotations (if any) to the original introspector
- AnnotatedClass ann2 = c.withAnnotations(newMap);
- return super._findTypeResolver(config, ann2, baseType);
- }
- }
- return super._findTypeResolver(config, ann, baseType);
- }
- }
}
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java
deleted file mode 100644
index 5a9b48a7c..000000000
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.api.transform.serde.legacy;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.Getter;
-import org.datavec.api.transform.serde.JsonMappers;
-import org.nd4j.serde.json.BaseLegacyDeserializer;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
-
-import java.util.Map;
-
-@AllArgsConstructor
-@Data
-public class GenericLegacyDeserializer extends BaseLegacyDeserializer {
-
- @Getter
- protected final Class deserializedType;
- @Getter
- protected final Map legacyNamesMap;
-
- @Override
- public ObjectMapper getLegacyJsonMapper() {
- return JsonMappers.getLegacyMapperFor(getDeserializedType());
- }
-}
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java
new file mode 100644
index 000000000..8df741a49
--- /dev/null
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java
@@ -0,0 +1,267 @@
+package org.datavec.api.transform.serde.legacy;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.datavec.api.transform.Transform;
+import org.datavec.api.transform.analysis.columns.*;
+import org.datavec.api.transform.condition.BooleanCondition;
+import org.datavec.api.transform.condition.Condition;
+import org.datavec.api.transform.condition.column.*;
+import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
+import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
+import org.datavec.api.transform.filter.ConditionFilter;
+import org.datavec.api.transform.filter.Filter;
+import org.datavec.api.transform.filter.FilterInvalidValues;
+import org.datavec.api.transform.filter.InvalidNumColumns;
+import org.datavec.api.transform.metadata.*;
+import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
+import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
+import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
+import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
+import org.datavec.api.transform.rank.CalculateSortedRank;
+import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.schema.SequenceSchema;
+import org.datavec.api.transform.sequence.ReduceSequenceTransform;
+import org.datavec.api.transform.sequence.SequenceComparator;
+import org.datavec.api.transform.sequence.SequenceSplit;
+import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
+import org.datavec.api.transform.sequence.comparator.StringComparator;
+import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
+import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
+import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
+import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
+import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
+import org.datavec.api.transform.sequence.window.TimeWindowFunction;
+import org.datavec.api.transform.sequence.window.WindowFunction;
+import org.datavec.api.transform.stringreduce.IStringReducer;
+import org.datavec.api.transform.stringreduce.StringReducer;
+import org.datavec.api.transform.transform.categorical.*;
+import org.datavec.api.transform.transform.column.*;
+import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
+import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
+import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
+import org.datavec.api.transform.transform.doubletransform.*;
+import org.datavec.api.transform.transform.integer.*;
+import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
+import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
+import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
+import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
+import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
+import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
+import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
+import org.datavec.api.transform.transform.string.*;
+import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
+import org.datavec.api.transform.transform.time.StringToTimeTransform;
+import org.datavec.api.transform.transform.time.TimeMathOpTransform;
+import org.datavec.api.writable.*;
+import org.datavec.api.writable.comparator.*;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonSubTypes;
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+
+/**
+ * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override
+ * the existing annotations.
+ * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field".
+ * We use these mixins to allow us to still load the old format
+ *
+ * @author Alex Black
+ */
+public class LegacyJsonFormat {
+
+ private LegacyJsonFormat(){ }
+
+ /**
+ * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before
+ * @return Object mapper
+ */
+ public static ObjectMapper legacyMapper(){
+ ObjectMapper om = new ObjectMapper();
+ om.addMixIn(Schema.class, SchemaMixin.class);
+ om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class);
+ om.addMixIn(Transform.class, TransformMixin.class);
+ om.addMixIn(Condition.class, ConditionMixin.class);
+ om.addMixIn(Writable.class, WritableMixin.class);
+ om.addMixIn(Filter.class, FilterMixin.class);
+ om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class);
+ om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class);
+ om.addMixIn(WindowFunction.class, WindowFunctionMixin.class);
+ om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class);
+ om.addMixIn(WritableComparator.class, WritableComparatorMixin.class);
+ om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class);
+ om.addMixIn(IStringReducer.class, IStringReducerMixin.class);
+ return om;
+ }
+
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"),
+ @JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class SchemaMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"),
+ @JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"),
+ @JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"),
+ @JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"),
+ @JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"),
+ @JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"),
+ @JsonSubTypes.Type(value = LongMetaData.class, name = "Long"),
+ @JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"),
+ @JsonSubTypes.Type(value = StringMetaData.class, name = "String"),
+ @JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class ColumnMetaDataMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"),
+ @JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"),
+ @JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"),
+ @JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"),
+ @JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"),
+ @JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"),
+ @JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"),
+ @JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"),
+ @JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"),
+ @JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"),
+ @JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"),
+ @JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"),
+ @JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"),
+ @JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"),
+ @JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"),
+ @JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"),
+ @JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"),
+ @JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"),
+ @JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"),
+ @JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"),
+ @JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"),
+ @JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"),
+ @JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"),
+ @JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"),
+ @JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"),
+ @JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"),
+ @JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"),
+ @JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"),
+ @JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"),
+ @JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"),
+ @JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"),
+ @JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"),
+ @JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"),
+ @JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"),
+ @JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"),
+ @JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"),
+ @JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"),
+ @JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"),
+ @JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"),
+ @JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"),
+ @JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"),
+ @JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"),
+ @JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"),
+ @JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"),
+ @JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"),
+ @JsonSubTypes.Type(value = SequenceOffsetTransform.class, name = "SequenceOffsetTransform"),
+ @JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"),
+ @JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"),
+ @JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"),
+ @JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"),
+ @JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"),
+ @JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"),
+ @JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"),
+ @JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"),
+ @JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"),
+ @JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class TransformMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"),
+ @JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"),
+ @JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"),
+ @JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"),
+ @JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"),
+ @JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"),
+ @JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"),
+ @JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"),
+ @JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"),
+ @JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"),
+ @JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"),
+ @JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"),
+ @JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class ConditionMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"),
+ @JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"),
+ @JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"),
+ @JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"),
+ @JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"),
+ @JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"),
+ @JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"),
+ @JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"),
+ @JsonSubTypes.Type(value = Text.class, name = "Text"),
+ @JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class WritableMixin { }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"),
+ @JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"),
+ @JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class FilterMixin { }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"),
+ @JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class SequenceComparatorMixin { }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"),
+ @JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class SequenceSplitMixin { }
+
+ @JsonInclude(JsonInclude.Include.NON_NULL)
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"),
+ @JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class WindowFunctionMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class CalculateSortedRankMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"),
+ @JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"),
+ @JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"),
+ @JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"),
+ @JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class WritableComparatorMixin { }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"),
+ @JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"),
+ @JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"),
+ @JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"),
+ @JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"),
+ @JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"),
+ @JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class ColumnAnalysisMixin{ }
+
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+ @JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")})
+ @NoArgsConstructor(access = AccessLevel.PRIVATE)
+ public static class IStringReducerMixin{ }
+}
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java
deleted file mode 100644
index c4b478278..000000000
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java
+++ /dev/null
@@ -1,535 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.api.transform.serde.legacy;
-
-import org.datavec.api.transform.Transform;
-import org.datavec.api.transform.analysis.columns.*;
-import org.datavec.api.transform.condition.BooleanCondition;
-import org.datavec.api.transform.condition.Condition;
-import org.datavec.api.transform.condition.column.*;
-import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
-import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
-import org.datavec.api.transform.filter.ConditionFilter;
-import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.transform.filter.FilterInvalidValues;
-import org.datavec.api.transform.filter.InvalidNumColumns;
-import org.datavec.api.transform.metadata.*;
-import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
-import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
-import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
-import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
-import org.datavec.api.transform.rank.CalculateSortedRank;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.schema.SequenceSchema;
-import org.datavec.api.transform.sequence.ReduceSequenceTransform;
-import org.datavec.api.transform.sequence.SequenceComparator;
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
-import org.datavec.api.transform.sequence.comparator.StringComparator;
-import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
-import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
-import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
-import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
-import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
-import org.datavec.api.transform.sequence.window.TimeWindowFunction;
-import org.datavec.api.transform.sequence.window.WindowFunction;
-import org.datavec.api.transform.stringreduce.IStringReducer;
-import org.datavec.api.transform.stringreduce.StringReducer;
-import org.datavec.api.transform.transform.categorical.*;
-import org.datavec.api.transform.transform.column.*;
-import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
-import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
-import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
-import org.datavec.api.transform.transform.doubletransform.*;
-import org.datavec.api.transform.transform.integer.*;
-import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
-import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
-import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
-import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform;
-import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
-import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
-import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
-import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
-import org.datavec.api.transform.transform.string.*;
-import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
-import org.datavec.api.transform.transform.time.StringToTimeTransform;
-import org.datavec.api.transform.transform.time.TimeMathOpTransform;
-import org.datavec.api.writable.*;
-import org.datavec.api.writable.comparator.*;
-import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
-
-import java.util.HashMap;
-import java.util.Map;
-
-public class LegacyMappingHelper {
-
- public static Map legacyMappingForClass(Class c){
- //Need to be able to get the map - and they need to be mutable...
- switch (c.getSimpleName()){
- case "Transform":
- return getLegacyMappingImageTransform();
- case "ColumnAnalysis":
- return getLegacyMappingColumnAnalysis();
- case "Condition":
- return getLegacyMappingCondition();
- case "Filter":
- return getLegacyMappingFilter();
- case "ColumnMetaData":
- return mapColumnMetaData;
- case "CalculateSortedRank":
- return mapCalculateSortedRank;
- case "Schema":
- return mapSchema;
- case "SequenceComparator":
- return mapSequenceComparator;
- case "SequenceSplit":
- return mapSequenceSplit;
- case "WindowFunction":
- return mapWindowFunction;
- case "IStringReducer":
- return mapIStringReducer;
- case "Writable":
- return mapWritable;
- case "WritableComparator":
- return mapWritableComparator;
- case "ImageTransform":
- return mapImageTransform;
- default:
- //Should never happen
- throw new IllegalArgumentException("No legacy mapping available for class " + c.getName());
- }
- }
-
- private static Map mapTransform;
- private static Map mapColumnAnalysis;
- private static Map mapCondition;
- private static Map mapFilter;
- private static Map mapColumnMetaData;
- private static Map mapCalculateSortedRank;
- private static Map mapSchema;
- private static Map mapSequenceComparator;
- private static Map mapSequenceSplit;
- private static Map mapWindowFunction;
- private static Map mapIStringReducer;
- private static Map mapWritable;
- private static Map mapWritableComparator;
- private static Map mapImageTransform;
-
- private static synchronized Map getLegacyMappingTransform(){
-
- if(mapTransform == null) {
- //The following classes all used their class short name
- Map m = new HashMap<>();
- m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName());
- m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName());
- m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName());
- m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName());
- m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName());
- m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName());
- m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName());
- m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName());
- m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName());
- m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName());
- m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName());
- m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName());
- m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName());
- m.put("Log2Normalizer", Log2Normalizer.class.getName());
- m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName());
- m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName());
- m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName());
- m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName());
- m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName());
- m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName());
- m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName());
- m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName());
- m.put("LongMathOpTransform", LongMathOpTransform.class.getName());
- m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName());
- m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName());
- m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName());
- m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName());
- m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName());
- m.put("StringMapTransform", StringMapTransform.class.getName());
- m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName());
- m.put("StringToTimeTransform", StringToTimeTransform.class.getName());
- m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName());
- m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName());
- m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName());
- m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName());
- m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName());
- m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName());
- m.put("ConvertToStringTransform", ConvertToString.class.getName());
- m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName());
- m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName());
- m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName());
- m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName());
- m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName());
- m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName());
- m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName());
- m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName());
- m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName());
- m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName());
- m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName());
- m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName());
- m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName());
- m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName());
- m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName());
- m.put("PivotTransform", PivotTransform.class.getName());
- m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName());
- m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName());
- m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName());
-
- mapTransform = m;
- }
-
- return mapTransform;
- }
-
- private static Map getLegacyMappingColumnAnalysis(){
- if(mapColumnAnalysis == null) {
- Map m = new HashMap<>();
- m.put("BytesAnalysis", BytesAnalysis.class.getName());
- m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName());
- m.put("DoubleAnalysis", DoubleAnalysis.class.getName());
- m.put("IntegerAnalysis", IntegerAnalysis.class.getName());
- m.put("LongAnalysis", LongAnalysis.class.getName());
- m.put("StringAnalysis", StringAnalysis.class.getName());
- m.put("TimeAnalysis", TimeAnalysis.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName());
-
- mapColumnAnalysis = m;
- }
-
- return mapColumnAnalysis;
- }
-
- private static Map getLegacyMappingCondition(){
- if(mapCondition == null) {
- Map m = new HashMap<>();
- m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName());
- m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName());
- m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName());
- m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName());
- m.put("LongColumnCondition", LongColumnCondition.class.getName());
- m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName());
- m.put("StringColumnCondition", StringColumnCondition.class.getName());
- m.put("TimeColumnCondition", TimeColumnCondition.class.getName());
- m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName());
- m.put("BooleanCondition", BooleanCondition.class.getName());
- m.put("NaNColumnCondition", NaNColumnCondition.class.getName());
- m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName());
- m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName());
- m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName());
-
- mapCondition = m;
- }
-
- return mapCondition;
- }
-
- private static Map getLegacyMappingFilter(){
- if(mapFilter == null) {
- Map m = new HashMap<>();
- m.put("ConditionFilter", ConditionFilter.class.getName());
- m.put("FilterInvalidValues", FilterInvalidValues.class.getName());
- m.put("InvalidNumCols", InvalidNumColumns.class.getName());
-
- mapFilter = m;
- }
- return mapFilter;
- }
-
- private static Map getLegacyMappingColumnMetaData(){
- if(mapColumnMetaData == null) {
- Map m = new HashMap<>();
- m.put("Categorical", CategoricalMetaData.class.getName());
- m.put("Double", DoubleMetaData.class.getName());
- m.put("Float", FloatMetaData.class.getName());
- m.put("Integer", IntegerMetaData.class.getName());
- m.put("Long", LongMetaData.class.getName());
- m.put("String", StringMetaData.class.getName());
- m.put("Time", TimeMetaData.class.getName());
- m.put("NDArray", NDArrayMetaData.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName());
- m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName());
-
- mapColumnMetaData = m;
- }
-
- return mapColumnMetaData;
- }
-
- private static Map getLegacyMappingCalculateSortedRank(){
- if(mapCalculateSortedRank == null) {
- Map m = new HashMap<>();
- m.put("CalculateSortedRank", CalculateSortedRank.class.getName());
- mapCalculateSortedRank = m;
- }
- return mapCalculateSortedRank;
- }
-
- private static Map getLegacyMappingSchema(){
- if(mapSchema == null) {
- Map m = new HashMap<>();
- m.put("Schema", Schema.class.getName());
- m.put("SequenceSchema", SequenceSchema.class.getName());
-
- mapSchema = m;
- }
- return mapSchema;
- }
-
- private static Map getLegacyMappingSequenceComparator(){
- if(mapSequenceComparator == null) {
- Map m = new HashMap<>();
- m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName());
- m.put("StringComparator", StringComparator.class.getName());
-
- mapSequenceComparator = m;
- }
- return mapSequenceComparator;
- }
-
- private static Map getLegacyMappingSequenceSplit(){
- if(mapSequenceSplit == null) {
- Map m = new HashMap<>();
- m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName());
- m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName());
-
- mapSequenceSplit = m;
- }
- return mapSequenceSplit;
- }
-
- private static Map getLegacyMappingWindowFunction(){
- if(mapWindowFunction == null) {
- Map m = new HashMap<>();
- m.put("TimeWindowFunction", TimeWindowFunction.class.getName());
- m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName());
-
- mapWindowFunction = m;
- }
- return mapWindowFunction;
- }
-
- private static Map getLegacyMappingIStringReducer(){
- if(mapIStringReducer == null) {
- Map m = new HashMap<>();
- m.put("StringReducer", StringReducer.class.getName());
-
- mapIStringReducer = m;
- }
- return mapIStringReducer;
- }
-
- private static Map getLegacyMappingWritable(){
- if (mapWritable == null) {
- Map m = new HashMap<>();
- m.put("ArrayWritable", ArrayWritable.class.getName());
- m.put("BooleanWritable", BooleanWritable.class.getName());
- m.put("ByteWritable", ByteWritable.class.getName());
- m.put("DoubleWritable", DoubleWritable.class.getName());
- m.put("FloatWritable", FloatWritable.class.getName());
- m.put("IntWritable", IntWritable.class.getName());
- m.put("LongWritable", LongWritable.class.getName());
- m.put("NullWritable", NullWritable.class.getName());
- m.put("Text", Text.class.getName());
- m.put("BytesWritable", BytesWritable.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName());
-
- mapWritable = m;
- }
-
- return mapWritable;
- }
-
- private static Map getLegacyMappingWritableComparator(){
- if(mapWritableComparator == null) {
- Map m = new HashMap<>();
- m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName());
- m.put("FloatWritableComparator", FloatWritableComparator.class.getName());
- m.put("IntWritableComparator", IntWritableComparator.class.getName());
- m.put("LongWritableComparator", LongWritableComparator.class.getName());
- m.put("TextWritableComparator", TextWritableComparator.class.getName());
-
- //The following never had subtype annotations, and hence will have had the default name:
- m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName());
- m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName());
- m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName());
- m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName());
- m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName());
- m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName());
- m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName());
- m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName());
-
- mapWritableComparator = m;
- }
-
- return mapWritableComparator;
- }
-
- public static Map getLegacyMappingImageTransform(){
- if(mapImageTransform == null) {
- Map m = new HashMap<>();
- m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform");
- m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform");
- m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform");
- m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform");
- m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform");
- m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform");
- m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform");
- m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform");
- m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform");
- m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform");
- m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform");
- m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform");
-
- mapImageTransform = m;
- }
- return mapImageTransform;
- }
-
- @JsonDeserialize(using = LegacyTransformDeserializer.class)
- public static class TransformHelper { }
-
- public static class LegacyTransformDeserializer extends GenericLegacyDeserializer {
- public LegacyTransformDeserializer() {
- super(Transform.class, getLegacyMappingTransform());
- }
- }
-
- @JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class)
- public static class ColumnAnalysisHelper { }
-
- public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer {
- public LegacyColumnAnalysisDeserializer() {
- super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis());
- }
- }
-
- @JsonDeserialize(using = LegacyConditionDeserializer.class)
- public static class ConditionHelper { }
-
- public static class LegacyConditionDeserializer extends GenericLegacyDeserializer {
- public LegacyConditionDeserializer() {
- super(Condition.class, getLegacyMappingCondition());
- }
- }
-
- @JsonDeserialize(using = LegacyFilterDeserializer.class)
- public static class FilterHelper { }
-
- public static class LegacyFilterDeserializer extends GenericLegacyDeserializer {
- public LegacyFilterDeserializer() {
- super(Filter.class, getLegacyMappingFilter());
- }
- }
-
- @JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class)
- public static class ColumnMetaDataHelper { }
-
- public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer {
- public LegacyColumnMetaDataDeserializer() {
- super(ColumnMetaData.class, getLegacyMappingColumnMetaData());
- }
- }
-
- @JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class)
- public static class CalculateSortedRankHelper { }
-
- public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer {
- public LegacyCalculateSortedRankDeserializer() {
- super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank());
- }
- }
-
- @JsonDeserialize(using = LegacySchemaDeserializer.class)
- public static class SchemaHelper { }
-
- public static class LegacySchemaDeserializer extends GenericLegacyDeserializer {
- public LegacySchemaDeserializer() {
- super(Schema.class, getLegacyMappingSchema());
- }
- }
-
- @JsonDeserialize(using = LegacySequenceComparatorDeserializer.class)
- public static class SequenceComparatorHelper { }
-
- public static class LegacySequenceComparatorDeserializer extends GenericLegacyDeserializer {
- public LegacySequenceComparatorDeserializer() {
- super(SequenceComparator.class, getLegacyMappingSequenceComparator());
- }
- }
-
- @JsonDeserialize(using = LegacySequenceSplitDeserializer.class)
- public static class SequenceSplitHelper { }
-
- public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer {
- public LegacySequenceSplitDeserializer() {
- super(SequenceSplit.class, getLegacyMappingSequenceSplit());
- }
- }
-
- @JsonDeserialize(using = LegacyWindowFunctionDeserializer.class)
- public static class WindowFunctionHelper { }
-
- public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer {
- public LegacyWindowFunctionDeserializer() {
- super(WindowFunction.class, getLegacyMappingWindowFunction());
- }
- }
-
-
- @JsonDeserialize(using = LegacyIStringReducerDeserializer.class)
- public static class IStringReducerHelper { }
-
- public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer {
- public LegacyIStringReducerDeserializer() {
- super(IStringReducer.class, getLegacyMappingIStringReducer());
- }
- }
-
-
- @JsonDeserialize(using = LegacyWritableDeserializer.class)
- public static class WritableHelper { }
-
- public static class LegacyWritableDeserializer extends GenericLegacyDeserializer {
- public LegacyWritableDeserializer() {
- super(Writable.class, getLegacyMappingWritable());
- }
- }
-
- @JsonDeserialize(using = LegacyWritableComparatorDeserializer.class)
- public static class WritableComparatorHelper { }
-
- public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer {
- public LegacyWritableComparatorDeserializer() {
- super(WritableComparator.class, getLegacyMappingWritableComparator());
- }
- }
-}
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java
index f43e189d8..54bd7b7c8 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java
@@ -17,7 +17,6 @@
package org.datavec.api.transform.stringreduce;
import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -31,8 +30,7 @@ import java.util.List;
* a single List
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.IStringReducerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface IStringReducer extends Serializable {
/**
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
index e3ba797c8..c55d4d3bb 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
@@ -16,7 +16,7 @@
package org.datavec.api.util.ndarray;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
index 2584d8b9f..ae5e3a567 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
@@ -17,7 +17,7 @@
package org.datavec.api.writable;
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
index 72f4b3c5b..39f41c076 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
@@ -17,7 +17,7 @@
package org.datavec.api.writable;
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
index 1b54e7d54..f0ab62bef 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
@@ -17,7 +17,7 @@
package org.datavec.api.writable;
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java
index 3803c8098..1c127e0f8 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java
@@ -17,7 +17,7 @@
package org.datavec.api.writable;
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java
index 58cd45829..4a767a183 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java
@@ -17,7 +17,7 @@
package org.datavec.api.writable;
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java
index 30eb5e25e..5085dd3f2 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java
@@ -16,7 +16,6 @@
package org.datavec.api.writable;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.DataInput;
@@ -60,8 +59,7 @@ import java.io.Serializable;
* }
*
*/
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.WritableHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Writable extends Serializable {
/**
* Serialize the fields of this object to out
.
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java
index 4d638124e..0a5ddddb8 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java
@@ -16,7 +16,7 @@
package org.datavec.api.writable.batch;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.Data;
import lombok.NonNull;
import org.datavec.api.writable.NDArrayWritable;
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java
index 07ef7ee56..b8e540f61 100644
--- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java
+++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java
@@ -16,16 +16,13 @@
package org.datavec.api.writable.comparator;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
-import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
import java.util.Comparator;
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface WritableComparator extends Comparator, Serializable {
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java
index 647c0af65..e8ce37bd3 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java
@@ -16,7 +16,7 @@
package org.datavec.api.split;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java
index e67fc1f61..c9fb57eb9 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java
@@ -16,7 +16,7 @@
package org.datavec.api.split.parittion;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java
index c98b58381..dff90f8b9 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java
@@ -78,8 +78,9 @@ public class TestJsonYaml {
public void testMissingPrimitives() {
Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build();
-
- String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n"
+ //Legacy format JSON
+ String strJson = "{\n" + " \"Schema\" : {\n"
+ + " \"columns\" : [ {\n" + " \"Double\" : {\n"
+ " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" +
//" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test
//" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java
index ce7d779dc..6dfacdd93 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java
@@ -16,7 +16,7 @@
package org.datavec.api.writable;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test;
diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml
index 4d4381790..645971a45 100644
--- a/datavec/datavec-arrow/pom.xml
+++ b/datavec/datavec-arrow/pom.xml
@@ -34,36 +34,6 @@
nd4j-arrow
${project.version}
-
- com.fasterxml.jackson.core
- jackson-core
- ${spark2.jackson.version}
-
-
- com.fasterxml.jackson.core
- jackson-databind
- ${spark2.jackson.version}
-
-
- com.fasterxml.jackson.core
- jackson-annotations
- ${spark2.jackson.version}
-
-
- com.fasterxml.jackson.dataformat
- jackson-dataformat-yaml
- ${spark2.jackson.version}
-
-
- com.fasterxml.jackson.dataformat
- jackson-dataformat-xml
- ${spark2.jackson.version}
-
-
- com.fasterxml.jackson.datatype
- jackson-datatype-joda
- ${spark2.jackson.version}
-
org.datavec
datavec-api
@@ -74,11 +44,6 @@
hppc
${hppc.version}
-
- com.google.guava
- guava
- ${guava.version}
-
org.apache.arrow
arrow-vector
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java
index d80f12a2e..a962dd84a 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java
@@ -16,7 +16,7 @@
package org.datavec.image.recordreader;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java
deleted file mode 100644
index 5e7b09c12..000000000
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java
+++ /dev/null
@@ -1,35 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.image.serde;
-
-import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
-import org.datavec.image.transform.ImageTransform;
-import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
-
-public class LegacyImageMappingHelper {
-
- @JsonDeserialize(using = LegacyImageTransformDeserializer.class)
- public static class ImageTransformHelper { }
-
- public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer {
- public LegacyImageTransformDeserializer() {
- super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform());
- }
- }
-
-}
diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java
index 39239494b..afcdf894f 100644
--- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java
+++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java
@@ -16,11 +16,8 @@
package org.datavec.image.transform;
-import lombok.Data;
import org.datavec.image.data.ImageWritable;
-import org.datavec.image.serde.LegacyImageMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude;
-import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.Random;
@@ -31,8 +28,7 @@ import java.util.Random;
* @author saudet
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ImageTransform {
/**
diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-hadoop/pom.xml
index 38889228c..c95e6d3bc 100644
--- a/datavec/datavec-hadoop/pom.xml
+++ b/datavec/datavec-hadoop/pom.xml
@@ -50,11 +50,6 @@
netty
${netty.version}
-
- com.google.guava
- guava
- ${guava.version}
-
org.apache.commons
commons-compress
diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
index 01bef8fa9..58f7a57db 100644
--- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
+++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
@@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
index cf1a801f5..1cbe47176 100644
--- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
+++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
@@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
index 992c91312..faf41cbb4 100644
--- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
+++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
@@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
index d35becd7a..7cd112c63 100644
--- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
+++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
@@ -16,7 +16,7 @@
package org.datavec.hadoop.records.writer;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.datavec.api.records.converter.RecordReaderConverter;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
diff --git a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
index 5360eae7e..276d79f88 100644
--- a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
+++ b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
@@ -16,7 +16,7 @@
package org.datavec.local.transforms.join;
-import com.google.common.collect.Iterables;
+import org.nd4j.shade.guava.collect.Iterables;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.functions.FlatMapFunctionAdapter;
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
index 940ad01cc..605b13b70 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
@@ -64,12 +64,6 @@
${datavec.version}
-
- com.typesafe.akka
- akka-cluster_2.11
- ${akka.version}
-
-
joda-time
joda-time
@@ -106,40 +100,10 @@
${snakeyaml.version}
-
- com.fasterxml.jackson.core
- jackson-core
- ${jackson.version}
-
-
-
- com.fasterxml.jackson.core
- jackson-databind
- ${jackson.version}
-
-
-
- com.fasterxml.jackson.core
- jackson-annotations
- ${jackson.version}
-
-
-
- com.fasterxml.jackson.datatype
- jackson-datatype-jdk8
- ${jackson.version}
-
-
-
- com.fasterxml.jackson.datatype
- jackson-datatype-jsr310
- ${jackson.version}
-
-
com.typesafe.play
play-java_2.11
- ${play.version}
+ ${playframework.version}
com.google.code.findbugs
@@ -161,25 +125,31 @@
com.typesafe.play
play-json_2.11
- ${play.version}
+ ${playframework.version}
com.typesafe.play
play-server_2.11
- ${play.version}
+ ${playframework.version}
com.typesafe.play
play_2.11
- ${play.version}
+ ${playframework.version}
com.typesafe.play
play-netty-server_2.11
- ${play.version}
+ ${playframework.version}
+
+
+
+ com.typesafe.akka
+ akka-cluster_2.11
+ 2.5.23
@@ -194,6 +164,12 @@
jcommander
${jcommander.version}
+
+
+ org.apache.spark
+ spark-core_2.11
+ ${spark.version}
+
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java
index 893ce4218..f20799905 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java
@@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*;
+import play.BuiltInComponents;
import play.Mode;
+import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File;
import java.io.IOException;
+import java.util.Base64;
+import java.util.Random;
import static play.mvc.Results.*;
@@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer {
System.exit(1);
}
- RoutingDsl routingDsl = new RoutingDsl();
-
-
if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
TransformProcess transformProcess = TransformProcess.fromJson(json);
@@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess");
}
+ //Set play secret key, if required
+ //http://www.playframework.com/documentation/latest/ApplicationSecret
+ String crypto = System.getProperty("play.crypto.secret");
+ if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
+ byte[] newCrypto = new byte[1024];
- routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+ new Random().nextBytes(newCrypto);
+
+ String base64 = Base64.getEncoder().encodeToString(newCrypto);
+ System.setProperty("play.crypto.secret", base64);
+ }
+
+
+ server = Server.forRouter(Mode.PROD, port, this::createRouter);
+ }
+
+ protected Router createRouter(BuiltInComponents b){
+ RoutingDsl routingDsl = RoutingDsl.fromComponents(b);
+
+ routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();
@@ -88,11 +107,11 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in GET /transformprocess",e);
return internalServerError(e.getMessage());
}
- })));
+ });
- routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformprocess").routingTo(req -> {
try {
- TransformProcess transformProcess = TransformProcess.fromJson(getJsonText());
+ TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
setCSVTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@@ -100,12 +119,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in POST /transformprocess",e);
return internalServerError(e.getMessage());
}
- })));
+ });
- routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> {
- if (isSequence()) {
+ routingDsl.POST("/transformincremental").routingTo(req -> {
+ if (isSequence(req)) {
try {
- BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+ BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
@@ -115,7 +134,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
- SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
+ SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
@@ -124,12 +143,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
- })));
+ });
- routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> {
- if (isSequence()) {
+ routingDsl.POST("/transform").routingTo(req -> {
+ if (isSequence(req)) {
try {
- SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class));
+ SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(batch)).as(contentType);
@@ -139,7 +158,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
- BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+ BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
BatchCSVRecord batch = transform(input);
if (batch == null)
return badRequest();
@@ -149,14 +168,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
+ });
-
- })));
-
- routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
- if (isSequence()) {
+ routingDsl.POST("/transformincrementalarray").routingTo(req -> {
+ if (isSequence(req)) {
try {
- BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+ BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
@@ -166,7 +183,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
- SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
+ SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
@@ -175,13 +192,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
+ });
- })));
-
- routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
- if (isSequence()) {
+ routingDsl.POST("/transformarray").routingTo(req -> {
+ if (isSequence(req)) {
try {
- SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class);
+ SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
@@ -191,7 +207,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
- BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+ BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
@@ -200,10 +216,9 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
- })));
+ });
-
- server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
+ return routingDsl.build();
}
public static void main(String[] args) throws Exception {
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java
deleted file mode 100644
index 6c4874b02..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import play.libs.F;
-import play.mvc.Result;
-
-import java.util.function.Function;
-import java.util.function.Supplier;
-
-/**
- * Utility methods for Routing
- *
- * @author Alex Black
- */
-public class FunctionUtil {
-
-
- public static F.Function0 function0(Supplier supplier) {
- return supplier::get;
- }
-
- public static F.Function function(Function function) {
- return function::apply;
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java
index 29c1e1bd7..f8675f139 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java
@@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*;
+import play.BuiltInComponents;
import play.Mode;
+import play.libs.Files;
import play.mvc.Http;
+import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
@@ -33,6 +36,7 @@ import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.Function;
import static play.mvc.Controller.request;
import static play.mvc.Results.*;
@@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer {
System.exit(1);
}
- RoutingDsl routingDsl = new RoutingDsl();
-
if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
@@ -73,7 +75,13 @@ public class ImageSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess");
}
- routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+ server = Server.forRouter(Mode.PROD, port, this::createRouter);
+ }
+
+ protected Router createRouter(BuiltInComponents builtInComponents){
+ RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
+
+ routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();
@@ -83,11 +91,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformprocess").routingTo(req -> {
try {
- ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText());
+ ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
setImageTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@@ -95,11 +103,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformincrementalarray").routingTo(req -> {
try {
- SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class);
+ SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@@ -107,17 +115,17 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformincrementalimage").routingTo(req -> {
try {
- Http.MultipartFormData body = request().body().asMultipartFormData();
- List files = body.getFiles();
- if (files.size() == 0 || files.get(0).getFile() == null) {
+ Http.MultipartFormData body = req.body().asMultipartFormData();
+ List> files = body.getFiles();
+ if (files.isEmpty() || files.get(0).getRef() == null ) {
return badRequest();
}
- File file = files.get(0).getFile();
+ File file = files.get(0).getRef().path().toFile();
SingleImageRecord record = new SingleImageRecord(file.toURI());
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@@ -125,11 +133,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformarray").routingTo(req -> {
try {
- BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class);
+ BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
@@ -137,22 +145,22 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/transformimage").routingTo(req -> {
try {
- Http.MultipartFormData body = request().body().asMultipartFormData();
- List files = body.getFiles();
+ Http.MultipartFormData body = req.body().asMultipartFormData();
+ List> files = body.getFiles();
if (files.size() == 0) {
return badRequest();
}
List records = new ArrayList<>();
- for (Http.MultipartFormData.FilePart filePart : files) {
- File file = filePart.getFile();
+ for (Http.MultipartFormData.FilePart filePart : files) {
+ Files.TemporaryFile file = filePart.getRef();
if (file != null) {
- SingleImageRecord record = new SingleImageRecord(file.toURI());
+ SingleImageRecord record = new SingleImageRecord(file.path().toUri());
records.add(record);
}
}
@@ -164,9 +172,9 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
- })));
+ });
- server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
+ return routingDsl.build();
}
@Override
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java
index 2d4c92836..411872006 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java
@@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper;
+import play.mvc.Http;
import play.server.Server;
import static play.mvc.Controller.request;
@@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService {
server.stop();
}
- protected boolean isSequence() {
- return request().hasHeader(SEQUENCE_OR_NOT_HEADER)
- && request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase()
- .equals("TRUE");
+ protected boolean isSequence(Http.Request request) {
+ return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
+ && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
}
-
- protected String getHeaderValue(String value) {
- if (request().hasHeader(value))
- return request().getHeader(value);
- return null;
- }
-
- protected String getJsonText() {
- JsonNode tryJson = request().body().asJson();
+ protected String getJsonText(Http.Request request) {
+ JsonNode tryJson = request.body().asJson();
if (tryJson != null)
return tryJson.toString();
else
- return request().body().asText();
+ return request.body().asText();
}
public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf
new file mode 100644
index 000000000..28a4aa208
--- /dev/null
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf
@@ -0,0 +1,350 @@
+# This is the main configuration file for the application.
+# https://www.playframework.com/documentation/latest/ConfigFile
+# ~~~~~
+# Play uses HOCON as its configuration file format. HOCON has a number
+# of advantages over other config formats, but there are two things that
+# can be used when modifying settings.
+#
+# You can include other configuration files in this main application.conf file:
+#include "extra-config.conf"
+#
+# You can declare variables and substitute for them:
+#mykey = ${some.value}
+#
+# And if an environment variable exists when there is no other subsitution, then
+# HOCON will fall back to substituting environment variable:
+#mykey = ${JAVA_HOME}
+
+## Akka
+# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
+# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
+# ~~~~~
+# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
+# other streaming HTTP responses.
+akka {
+ # "akka.log-config-on-start" is extraordinarly useful because it log the complete
+ # configuration at INFO level, including defaults and overrides, so it s worth
+ # putting at the very top.
+ #
+ # Put the following in your conf/logback.xml file:
+ #
+ #
+ #
+ # And then uncomment this line to debug the configuration.
+ #
+ #log-config-on-start = true
+}
+
+## Modules
+# https://www.playframework.com/documentation/latest/Modules
+# ~~~~~
+# Control which modules are loaded when Play starts. Note that modules are
+# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
+# Please see https://www.playframework.com/documentation/latest/GlobalSettings
+# for more information.
+#
+# You can also extend Play functionality by using one of the publically available
+# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
+play.modules {
+ # By default, Play will load any class called Module that is defined
+ # in the root package (the "app" directory), or you can define them
+ # explicitly below.
+ # If there are any built-in modules that you want to disable, you can list them here.
+ #enabled += my.application.Module
+
+ # If there are any built-in modules that you want to disable, you can list them here.
+ #disabled += ""
+}
+
+## Internationalisation
+# https://www.playframework.com/documentation/latest/JavaI18N
+# https://www.playframework.com/documentation/latest/ScalaI18N
+# ~~~~~
+# Play comes with its own i18n settings, which allow the user's preferred language
+# to map through to internal messages, or allow the language to be stored in a cookie.
+play.i18n {
+ # The application languages
+ langs = [ "en" ]
+
+ # Whether the language cookie should be secure or not
+ #langCookieSecure = true
+
+ # Whether the HTTP only attribute of the cookie should be set to true
+ #langCookieHttpOnly = true
+}
+
+## Play HTTP settings
+# ~~~~~
+play.http {
+ ## Router
+ # https://www.playframework.com/documentation/latest/JavaRouting
+ # https://www.playframework.com/documentation/latest/ScalaRouting
+ # ~~~~~
+ # Define the Router object to use for this application.
+ # This router will be looked up first when the application is starting up,
+ # so make sure this is the entry point.
+ # Furthermore, it's assumed your route file is named properly.
+ # So for an application router like `my.application.Router`,
+ # you may need to define a router file `conf/my.application.routes`.
+ # Default to Routes in the root package (aka "apps" folder) (and conf/routes)
+ #router = my.application.Router
+
+ ## Action Creator
+ # https://www.playframework.com/documentation/latest/JavaActionCreator
+ # ~~~~~
+ #actionCreator = null
+
+ ## ErrorHandler
+ # https://www.playframework.com/documentation/latest/JavaRouting
+ # https://www.playframework.com/documentation/latest/ScalaRouting
+ # ~~~~~
+ # If null, will attempt to load a class called ErrorHandler in the root package,
+ #errorHandler = null
+
+ ## Filters
+ # https://www.playframework.com/documentation/latest/ScalaHttpFilters
+ # https://www.playframework.com/documentation/latest/JavaHttpFilters
+ # ~~~~~
+ # Filters run code on every request. They can be used to perform
+ # common logic for all your actions, e.g. adding common headers.
+ # Defaults to "Filters" in the root package (aka "apps" folder)
+ # Alternatively you can explicitly register a class here.
+ #filters += my.application.Filters
+
+ ## Session & Flash
+ # https://www.playframework.com/documentation/latest/JavaSessionFlash
+ # https://www.playframework.com/documentation/latest/ScalaSessionFlash
+ # ~~~~~
+ session {
+ # Sets the cookie to be sent only over HTTPS.
+ #secure = true
+
+ # Sets the cookie to be accessed only by the server.
+ #httpOnly = true
+
+ # Sets the max-age field of the cookie to 5 minutes.
+ # NOTE: this only sets when the browser will discard the cookie. Play will consider any
+ # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
+ # you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
+ #maxAge = 300
+
+ # Sets the domain on the session cookie.
+ #domain = "example.com"
+ }
+
+ flash {
+ # Sets the cookie to be sent only over HTTPS.
+ #secure = true
+
+ # Sets the cookie to be accessed only by the server.
+ #httpOnly = true
+ }
+}
+
+## Netty Provider
+# https://www.playframework.com/documentation/latest/SettingsNetty
+# ~~~~~
+play.server.netty {
+ # Whether the Netty wire should be logged
+ #log.wire = true
+
+ # If you run Play on Linux, you can use Netty's native socket transport
+ # for higher performance with less garbage.
+ #transport = "native"
+}
+
+## WS (HTTP Client)
+# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
+# ~~~~~
+# The HTTP client primarily used for REST APIs. The default client can be
+# configured directly, but you can also create different client instances
+# with customized settings. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += ws // or javaWs if using java
+#
+play.ws {
+ # Sets HTTP requests not to follow 302 requests
+ #followRedirects = false
+
+ # Sets the maximum number of open HTTP connections for the client.
+ #ahc.maxConnectionsTotal = 50
+
+ ## WS SSL
+ # https://www.playframework.com/documentation/latest/WsSSL
+ # ~~~~~
+ ssl {
+ # Configuring HTTPS with Play WS does not require programming. You can
+ # set up both trustManager and keyManager for mutual authentication, and
+ # turn on JSSE debugging in development with a reload.
+ #debug.handshake = true
+ #trustManager = {
+ # stores = [
+ # { type = "JKS", path = "exampletrust.jks" }
+ # ]
+ #}
+ }
+}
+
+## Cache
+# https://www.playframework.com/documentation/latest/JavaCache
+# https://www.playframework.com/documentation/latest/ScalaCache
+# ~~~~~
+# Play comes with an integrated cache API that can reduce the operational
+# overhead of repeated requests. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += cache
+#
+play.cache {
+ # If you want to bind several caches, you can bind the individually
+ #bindCaches = ["db-cache", "user-cache", "session-cache"]
+}
+
+## Filters
+# https://www.playframework.com/documentation/latest/Filters
+# ~~~~~
+# There are a number of built-in filters that can be enabled and configured
+# to give Play greater security. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += filters
+#
+play.filters {
+ ## CORS filter configuration
+ # https://www.playframework.com/documentation/latest/CorsFilter
+ # ~~~~~
+ # CORS is a protocol that allows web applications to make requests from the browser
+ # across different domains.
+ # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
+ # dependencies on CORS settings.
+ cors {
+ # Filter paths by a whitelist of path prefixes
+ #pathPrefixes = ["/some/path", ...]
+
+ # The allowed origins. If null, all origins are allowed.
+ #allowedOrigins = ["http://www.example.com"]
+
+ # The allowed HTTP methods. If null, all methods are allowed
+ #allowedHttpMethods = ["GET", "POST"]
+ }
+
+ ## CSRF Filter
+ # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
+ # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
+ # ~~~~~
+ # Play supports multiple methods for verifying that a request is not a CSRF request.
+ # The primary mechanism is a CSRF token. This token gets placed either in the query string
+ # or body of every form submitted, and also gets placed in the users session.
+ # Play then verifies that both tokens are present and match.
+ csrf {
+ # Sets the cookie to be sent only over HTTPS
+ #cookie.secure = true
+
+ # Defaults to CSRFErrorHandler in the root package.
+ #errorHandler = MyCSRFErrorHandler
+ }
+
+ ## Security headers filter configuration
+ # https://www.playframework.com/documentation/latest/SecurityHeaders
+ # ~~~~~
+ # Defines security headers that prevent XSS attacks.
+ # If enabled, then all options are set to the below configuration by default:
+ headers {
+ # The X-Frame-Options header. If null, the header is not set.
+ #frameOptions = "DENY"
+
+ # The X-XSS-Protection header. If null, the header is not set.
+ #xssProtection = "1; mode=block"
+
+ # The X-Content-Type-Options header. If null, the header is not set.
+ #contentTypeOptions = "nosniff"
+
+ # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
+ #permittedCrossDomainPolicies = "master-only"
+
+ # The Content-Security-Policy header. If null, the header is not set.
+ #contentSecurityPolicy = "default-src 'self'"
+ }
+
+ ## Allowed hosts filter configuration
+ # https://www.playframework.com/documentation/latest/AllowedHostsFilter
+ # ~~~~~
+ # Play provides a filter that lets you configure which hosts can access your application.
+ # This is useful to prevent cache poisoning attacks.
+ hosts {
+ # Allow requests to example.com, its subdomains, and localhost:9000.
+ #allowed = [".example.com", "localhost:9000"]
+ }
+}
+
+## Evolutions
+# https://www.playframework.com/documentation/latest/Evolutions
+# ~~~~~
+# Evolutions allows database scripts to be automatically run on startup in dev mode
+# for database migrations. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += evolutions
+#
+play.evolutions {
+ # You can disable evolutions for a specific datasource if necessary
+ #db.default.enabled = false
+}
+
+## Database Connection Pool
+# https://www.playframework.com/documentation/latest/SettingsJDBC
+# ~~~~~
+# Play doesn't require a JDBC database to run, but you can easily enable one.
+#
+# libraryDependencies += jdbc
+#
+play.db {
+ # The combination of these two settings results in "db.default" as the
+ # default JDBC pool:
+ #config = "db"
+ #default = "default"
+
+ # Play uses HikariCP as the default connection pool. You can override
+ # settings by changing the prototype:
+ prototype {
+ # Sets a fixed JDBC connection pool size of 50
+ #hikaricp.minimumIdle = 50
+ #hikaricp.maximumPoolSize = 50
+ }
+}
+
+## JDBC Datasource
+# https://www.playframework.com/documentation/latest/JavaDatabase
+# https://www.playframework.com/documentation/latest/ScalaDatabase
+# ~~~~~
+# Once JDBC datasource is set up, you can work with several different
+# database options:
+#
+# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
+# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
+# EBean: https://playframework.com/documentation/latest/JavaEbean
+# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
+#
+db {
+ # You can declare as many datasources as you want.
+ # By convention, the default datasource is named `default`
+
+ # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
+ default.driver = org.h2.Driver
+ default.url = "jdbc:h2:mem:play"
+ #default.username = sa
+ #default.password = ""
+
+ # You can expose this datasource via JNDI if needed (Useful for JPA)
+ default.jndiName=DefaultDS
+
+ # You can turn on SQL logging for any datasource
+ # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
+ #default.logSql=true
+}
+
+jpa.default=defaultPersistenceUnit
+
+
+#Increase default maximum post length - used for remote listener functionality
+#Can get response 413 with larger networks without setting this
+# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
+#parsers.text.maxLength=10M
+play.http.parser.maxMemoryBuffer=10M
diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml
index d98730407..05c505cac 100644
--- a/datavec/datavec-spark/pom.xml
+++ b/datavec/datavec-spark/pom.xml
@@ -28,61 +28,11 @@
datavec-spark_2.11
-
- 2.1.0
- 2
-
2.11.12
2.11
-
-
-
-
- org.codehaus.mojo
- build-helper-maven-plugin
-
-
- add-source
- generate-sources
-
- add-source
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- com.fasterxml.jackson.datatype
- jackson-datatype-jsr310
- ${jackson.version}
-
-
- com.fasterxml.jackson.dataformat
- jackson-dataformat-yaml
- ${jackson.version}
-
-
- com.fasterxml.jackson.module
- jackson-module-scala_2.11
- ${jackson.version}
-
-
-
-
org.scala-lang
@@ -95,42 +45,13 @@
${scala.version}
-
- org.codehaus.jackson
- jackson-core-asl
- ${jackson-asl.version}
-
-
- org.codehaus.jackson
- jackson-mapper-asl
- ${jackson-asl.version}
-
org.apache.spark
spark-sql_2.11
${spark.version}
+ provided
-
- com.google.guava
- guava
- ${guava.version}
-
-
- com.google.inject
- guice
- ${guice.version}
-
-
- com.google.protobuf
- protobuf-java
- ${google.protobuf.version}
-
-
- commons-codec
- commons-codec
- ${commons-codec.version}
-
commons-collections
commons-collections
@@ -141,96 +62,16 @@
commons-io
${commons-io.version}
-
- commons-lang
- commons-lang
- ${commons-lang.version}
-
-
- commons-net
- commons-net
- ${commons-net.version}
-
-
- com.sun.xml.bind
- jaxb-core
- ${jaxb.version}
-
-
- com.sun.xml.bind
- jaxb-impl
- ${jaxb.version}
-
-
- com.typesafe.akka
- akka-actor_2.11
- ${akka.version}
-
-
- com.typesafe.akka
- akka-remote_2.11
- ${akka.version}
-
-
- com.typesafe.akka
- akka-slf4j_2.11
- ${akka.version}
-
-
- io.netty
- netty
- ${netty.version}
-
-
- com.fasterxml.jackson.core
- jackson-core
- ${jackson.version}
-
-
- com.fasterxml.jackson.core
- jackson-databind
- ${jackson.version}
-
-
- com.fasterxml.jackson.core
- jackson-annotations
- ${jackson.version}
-
-
- javax.servlet
- javax.servlet-api
- ${servlet.version}
-
-
- org.apache.commons
- commons-compress
- ${commons-compress.version}
-
-
- org.apache.commons
- commons-lang3
- ${commons-lang3.version}
-
org.apache.commons
commons-math3
${commons-math3.version}
-
- org.apache.curator
- curator-recipes
- ${curator.version}
-
org.slf4j
slf4j-api
${slf4j.version}
-
- com.typesafe
- config
- ${typesafe.config.version}
-
org.apache.spark
spark-core_2.11
@@ -241,14 +82,6 @@
com.google.code.findbugs
jsr305
-
- org.slf4j
- slf4j-log4j12
-
-
- log4j
- log4j
-
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java
deleted file mode 100644
index 8aeae58a5..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.functions;
-
-import java.io.Serializable;
-
-/**
- *
- * A function that returns zero or more output records from each input record.
- *
- * Adapter for Spark interface in order to freeze interface changes between spark versions
- */
-public interface FlatMapFunctionAdapter extends Serializable {
- Iterable call(T t) throws Exception;
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java
index fbe8a63d4..5d0bff7f3 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java
@@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.sql.Column;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.functions;
+import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -46,7 +43,6 @@ import java.util.List;
import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
-import static org.datavec.spark.transform.DataRowsFacade.dataRows;
/**
@@ -71,7 +67,7 @@ public class DataFrames {
* deviation for
* @return the column that represents the standard deviation
*/
- public static Column std(DataRowsFacade dataFrame, String columnName) {
+ public static Column std(Dataset dataFrame, String columnName) {
return functions.sqrt(var(dataFrame, columnName));
}
@@ -85,8 +81,8 @@ public class DataFrames {
* deviation for
* @return the column that represents the standard deviation
*/
- public static Column var(DataRowsFacade dataFrame, String columnName) {
- return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
+ public static Column var(Dataset dataFrame, String columnName) {
+ return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
}
/**
@@ -97,8 +93,8 @@ public class DataFrames {
* @param columnName the name of the column to get the min for
* @return the column that represents the min
*/
- public static Column min(DataRowsFacade dataFrame, String columnName) {
- return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName);
+ public static Column min(Dataset dataFrame, String columnName) {
+ return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName);
}
/**
@@ -110,8 +106,8 @@ public class DataFrames {
* to get the max for
* @return the column that represents the max
*/
- public static Column max(DataRowsFacade dataFrame, String columnName) {
- return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName);
+ public static Column max(Dataset dataFrame, String columnName) {
+ return dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName);
}
/**
@@ -122,8 +118,8 @@ public class DataFrames {
* @param columnName the name of the column to get the mean for
* @return the column that represents the mean
*/
- public static Column mean(DataRowsFacade dataFrame, String columnName) {
- return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName);
+ public static Column mean(Dataset dataFrame, String columnName) {
+ return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName);
}
/**
@@ -166,7 +162,7 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example
- * using {@link #toRecordsSequence(DataRowsFacade)}
+ * using {@link #toRecordsSequence(Dataset)}
*
* @param schema Schema to convert
* @return StructType for the schema
@@ -250,9 +246,9 @@ public class DataFrames {
* @param dataFrame the dataframe to convert
* @return the converted schema and rdd of writables
*/
- public static Pair>> toRecords(DataRowsFacade dataFrame) {
- Schema schema = fromStructType(dataFrame.get().schema());
- return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema)));
+ public static Pair>> toRecords(Dataset dataFrame) {
+ Schema schema = fromStructType(dataFrame.schema());
+ return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema)));
}
/**
@@ -267,11 +263,11 @@ public class DataFrames {
* @param dataFrame Data frame to convert
* @return Data in sequence (i.e., {@code List>} form
*/
- public static Pair>>> toRecordsSequence(DataRowsFacade dataFrame) {
+ public static Pair>>> toRecordsSequence(Dataset dataFrame) {
//Need to convert from flattened to sequence data...
//First: Group by the Sequence UUID (first column)
- JavaPairRDD> grouped = dataFrame.get().javaRDD().groupBy(new Function() {
+ JavaPairRDD> grouped = dataFrame.javaRDD().groupBy(new Function() {
@Override
public String call(Row row) throws Exception {
return row.getString(0);
@@ -279,7 +275,7 @@ public class DataFrames {
});
- Schema schema = fromStructType(dataFrame.get().schema());
+ Schema schema = fromStructType(dataFrame.schema());
//Group by sequence UUID, and sort each row within the sequences using the time step index
Function, List>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner
@@ -318,11 +314,11 @@ public class DataFrames {
* @param data the data to convert
* @return the dataframe object
*/
- public static DataRowsFacade toDataFrame(Schema schema, JavaRDD> data) {
+ public static Dataset toDataFrame(Schema schema, JavaRDD> data) {
JavaSparkContext sc = new JavaSparkContext(data.context());
SQLContext sqlContext = new SQLContext(sc);
JavaRDD rows = data.map(new ToRow(schema));
- return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema)));
+ return sqlContext.createDataFrame(rows, fromSchema(schema));
}
@@ -333,18 +329,18 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example
- * using {@link #toRecordsSequence(DataRowsFacade)}
+ * using {@link #toRecordsSequence(Dataset)}
*
* @param schema Schema for the data
* @param data Sequence data to convert to a DataFrame
* @return The dataframe object
*/
- public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD>> data) {
+ public static Dataset toDataFrameSequence(Schema schema, JavaRDD>> data) {
JavaSparkContext sc = new JavaSparkContext(data.context());
SQLContext sqlContext = new SQLContext(sc);
JavaRDD rows = data.flatMap(new SequenceToRows(schema));
- return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema)));
+ return sqlContext.createDataFrame(rows, fromSchemaSequence(schema));
}
/**
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java
index 68efd1888..cacea101d 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java
@@ -19,14 +19,13 @@ package org.datavec.spark.transform;
import org.apache.commons.collections.map.ListOrderedMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import java.util.*;
-import static org.datavec.spark.transform.DataRowsFacade.dataRows;
-
/**
* Simple dataframe based normalization.
@@ -46,7 +45,7 @@ public class Normalization {
* @return a zero mean unit variance centered
* rdd
*/
- public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) {
+ public static Dataset zeromeanUnitVariance(Dataset frame) {
return zeromeanUnitVariance(frame, Collections.emptyList());
}
@@ -71,7 +70,7 @@ public class Normalization {
* @param max the maximum value
* @return the normalized dataframe per column
*/
- public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) {
+ public static Dataset normalize(Dataset dataFrame, double min, double max) {
return normalize(dataFrame, min, max, Collections.emptyList());
}
@@ -86,7 +85,7 @@ public class Normalization {
*/
public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min,
double max) {
- DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+ Dataset frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, Collections.emptyList())).getSecond();
}
@@ -97,7 +96,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale
* @return the normalized dataframe per column
*/
- public static DataRowsFacade normalize(DataRowsFacade dataFrame) {
+ public static Dataset normalize(Dataset dataFrame) {
return normalize(dataFrame, 0, 1, Collections.emptyList());
}
@@ -120,8 +119,8 @@ public class Normalization {
* @return a zero mean unit variance centered
* rdd
*/
- public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List skipColumns) {
- List columnsList = DataFrames.toList(frame.get().columns());
+ public static Dataset zeromeanUnitVariance(Dataset frame, List skipColumns) {
+ List columnsList = DataFrames.toList(frame.columns());
columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList);
//first row is std second row is mean, each column in a row is for a particular column
@@ -133,7 +132,7 @@ public class Normalization {
if (std == 0.0)
std = 1; //All same value -> (x-x)/1 = 0
- frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std)));
+ frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std));
}
@@ -152,7 +151,7 @@ public class Normalization {
*/
public static JavaRDD> zeromeanUnitVariance(Schema schema, JavaRDD> data,
List skipColumns) {
- DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+ Dataset frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond();
}
@@ -178,7 +177,7 @@ public class Normalization {
*/
public static JavaRDD>> zeroMeanUnitVarianceSequence(Schema schema,
JavaRDD>> sequence, List excludeColumns) {
- DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence);
+ Dataset frame = DataFrames.toDataFrameSequence(schema, sequence);
if (excludeColumns == null)
excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN);
else {
@@ -196,7 +195,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
- public static List minMaxColumns(DataRowsFacade data, List columns) {
+ public static List minMaxColumns(Dataset data, List columns) {
String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i);
@@ -210,7 +209,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
- public static List minMaxColumns(DataRowsFacade data, String... columns) {
+ public static List minMaxColumns(Dataset data, String... columns) {
return aggregate(data, columns, new String[] {"min", "max"});
}
@@ -221,7 +220,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
- public static List stdDevMeanColumns(DataRowsFacade data, List columns) {
+ public static List stdDevMeanColumns(Dataset data, List columns) {
String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i);
@@ -237,7 +236,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
- public static List stdDevMeanColumns(DataRowsFacade data, String... columns) {
+ public static List stdDevMeanColumns(Dataset data, String... columns) {
return aggregate(data, columns, new String[] {"stddev", "mean"});
}
@@ -251,7 +250,7 @@ public class Normalization {
* Each row will be a function with the desired columnar output
* in the order in which the columns were specified.
*/
- public static List aggregate(DataRowsFacade data, String[] columns, String[] functions) {
+ public static List aggregate(Dataset data, String[] columns, String[] functions) {
String[] rest = new String[columns.length - 1];
System.arraycopy(columns, 1, rest, 0, rest.length);
List rows = new ArrayList<>();
@@ -262,8 +261,8 @@ public class Normalization {
}
//compute the aggregation based on the operation
- DataRowsFacade aggregated = dataRows(data.get().agg(expressions));
- String[] columns2 = aggregated.get().columns();
+ Dataset aggregated = data.agg(expressions);
+ String[] columns2 = aggregated.columns();
//strip out the op name and parentheses from the columns
Map opReplace = new TreeMap<>();
for (String s : columns2) {
@@ -278,20 +277,20 @@ public class Normalization {
//get rid of the operation name in the column
- DataRowsFacade rearranged = null;
+ Dataset rearranged = null;
for (Map.Entry entries : opReplace.entrySet()) {
//first column
if (rearranged == null) {
- rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue()));
+ rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue());
}
//rearranged is just a copy of aggregated at this point
else
- rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue()));
+ rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue());
}
- rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns)));
+ rearranged = rearranged.select(DataFrames.toColumns(columns));
//op
- rows.addAll(rearranged.get().collectAsList());
+ rows.addAll(rearranged.collectAsList());
}
@@ -307,8 +306,8 @@ public class Normalization {
* @param max the maximum value
* @return the normalized dataframe per column
*/
- public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List skipColumns) {
- List columnsList = DataFrames.toList(dataFrame.get().columns());
+ public static Dataset normalize(Dataset dataFrame, double min, double max, List skipColumns) {
+ List columnsList = DataFrames.toList(dataFrame.columns());
columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList);
//first row is min second row is max, each column in a row is for a particular column
@@ -321,8 +320,8 @@ public class Normalization {
if (maxSubMin == 0)
maxSubMin = 1;
- Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
- dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol));
+ Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
+ dataFrame = dataFrame.withColumn(columnName, newCol);
}
@@ -340,7 +339,7 @@ public class Normalization {
*/
public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max,
List skipColumns) {
- DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+ Dataset frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond();
}
@@ -387,7 +386,7 @@ public class Normalization {
excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN);
excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN);
}
- DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data);
+ Dataset frame = DataFrames.toDataFrameSequence(schema, data);
return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond();
}
@@ -398,7 +397,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale
* @return the normalized dataframe per column
*/
- public static DataRowsFacade normalize(DataRowsFacade dataFrame, List skipColumns) {
+ public static Dataset normalize(Dataset dataFrame, List skipColumns) {
return normalize(dataFrame, 0, 1, skipColumns);
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java
index 6b7ff203f..5052491bb 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java
@@ -16,9 +16,10 @@
package org.datavec.spark.transform.analysis;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import java.util.Iterator;
import java.util.List;
/**
@@ -27,10 +28,11 @@ import java.util.List;
*
* @author Alex Black
*/
-public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee>, List> {
+public class SequenceFlatMapFunction implements FlatMapFunction>, List> {
- public SequenceFlatMapFunction() {
- super(new SequenceFlatMapFunctionAdapter());
+ @Override
+ public Iterator> call(List> collections) throws Exception {
+ return collections.iterator();
}
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java
deleted file mode 100644
index 6b25fb826..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.analysis;
-
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.List;
-
-/**
- * SequenceFlatMapFunction: very simple function used to flatten a sequence
- * Typically used only internally for certain analysis operations
- *
- * @author Alex Black
- */
-public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter>, List> {
- @Override
- public Iterable> call(List> collections) throws Exception {
- return collections;
- }
-
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java
index 52bf924be..6e501f560 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java
@@ -16,11 +16,14 @@
package org.datavec.spark.transform.join;
+import org.nd4j.shade.guava.collect.Iterables;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import scala.Tuple2;
+import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
/**
@@ -28,10 +31,89 @@ import java.util.List;
*
* @author Alex Black
*/
-public class ExecuteJoinFromCoGroupFlatMapFunction extends
- BaseFlatMapFunctionAdaptee, Tuple2>, Iterable>>>, List> {
+public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction, Tuple2>, Iterable>>>, List> {
+
+ private final Join join;
public ExecuteJoinFromCoGroupFlatMapFunction(Join join) {
- super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join));
+ this.join = join;
+ }
+
+ @Override
+ public Iterator> call(
+ Tuple2, Tuple2>, Iterable>>> t2)
+ throws Exception {
+
+ Iterable> leftList = t2._2()._1();
+ Iterable> rightList = t2._2()._2();
+
+ List> ret = new ArrayList<>();
+ Join.JoinType jt = join.getJoinType();
+ switch (jt) {
+ case Inner:
+ //Return records where key columns appear in BOTH
+ //So if no values from left OR right: no return values
+ for (List jvl : leftList) {
+ for (List jvr : rightList) {
+ List joined = join.joinExamples(jvl, jvr);
+ ret.add(joined);
+ }
+ }
+ break;
+ case LeftOuter:
+ //Return all records from left, even if no corresponding right value (NullWritable in that case)
+ for (List jvl : leftList) {
+ if (Iterables.size(rightList) == 0) {
+ List joined = join.joinExamples(jvl, null);
+ ret.add(joined);
+ } else {
+ for (List jvr : rightList) {
+ List joined = join.joinExamples(jvl, jvr);
+ ret.add(joined);
+ }
+ }
+ }
+ break;
+ case RightOuter:
+ //Return all records from right, even if no corresponding left value (NullWritable in that case)
+ for (List jvr : rightList) {
+ if (Iterables.size(leftList) == 0) {
+ List joined = join.joinExamples(null, jvr);
+ ret.add(joined);
+ } else {
+ for (List jvl : leftList) {
+ List joined = join.joinExamples(jvl, jvr);
+ ret.add(joined);
+ }
+ }
+ }
+ break;
+ case FullOuter:
+ //Return all records, even if no corresponding left/right value (NullWritable in that case)
+ if (Iterables.size(leftList) == 0) {
+ //Only right values
+ for (List jvr : rightList) {
+ List joined = join.joinExamples(null, jvr);
+ ret.add(joined);
+ }
+ } else if (Iterables.size(rightList) == 0) {
+ //Only left values
+ for (List jvl : leftList) {
+ List joined = join.joinExamples(jvl, null);
+ ret.add(joined);
+ }
+ } else {
+ //Records from both left and right
+ for (List jvl : leftList) {
+ for (List jvr : rightList) {
+ List joined = join.joinExamples(jvl, jvr);
+ ret.add(joined);
+ }
+ }
+ }
+ break;
+ }
+
+ return ret.iterator();
}
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
deleted file mode 100644
index dedff46d0..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java
+++ /dev/null
@@ -1,119 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.join;
-
-import com.google.common.collect.Iterables;
-import org.datavec.api.transform.join.Join;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-import scala.Tuple2;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Execute a join
- *
- * @author Alex Black
- */
-public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements
- FlatMapFunctionAdapter, Tuple2>, Iterable>>>, List> {
-
- private final Join join;
-
- public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) {
- this.join = join;
- }
-
- @Override
- public Iterable> call(
- Tuple2, Tuple2>, Iterable>>> t2)
- throws Exception {
-
- Iterable> leftList = t2._2()._1();
- Iterable> rightList = t2._2()._2();
-
- List> ret = new ArrayList<>();
- Join.JoinType jt = join.getJoinType();
- switch (jt) {
- case Inner:
- //Return records where key columns appear in BOTH
- //So if no values from left OR right: no return values
- for (List jvl : leftList) {
- for (List jvr : rightList) {
- List joined = join.joinExamples(jvl, jvr);
- ret.add(joined);
- }
- }
- break;
- case LeftOuter:
- //Return all records from left, even if no corresponding right value (NullWritable in that case)
- for (List jvl : leftList) {
- if (Iterables.size(rightList) == 0) {
- List joined = join.joinExamples(jvl, null);
- ret.add(joined);
- } else {
- for (List jvr : rightList) {
- List joined = join.joinExamples(jvl, jvr);
- ret.add(joined);
- }
- }
- }
- break;
- case RightOuter:
- //Return all records from right, even if no corresponding left value (NullWritable in that case)
- for (List jvr : rightList) {
- if (Iterables.size(leftList) == 0) {
- List joined = join.joinExamples(null, jvr);
- ret.add(joined);
- } else {
- for (List jvl : leftList) {
- List joined = join.joinExamples(jvl, jvr);
- ret.add(joined);
- }
- }
- }
- break;
- case FullOuter:
- //Return all records, even if no corresponding left/right value (NullWritable in that case)
- if (Iterables.size(leftList) == 0) {
- //Only right values
- for (List jvr : rightList) {
- List joined = join.joinExamples(null, jvr);
- ret.add(joined);
- }
- } else if (Iterables.size(rightList) == 0) {
- //Only left values
- for (List jvl : leftList) {
- List joined = join.joinExamples(jvl, null);
- ret.add(joined);
- }
- } else {
- //Records from both left and right
- for (List jvl : leftList) {
- for (List jvr : rightList) {
- List joined = join.joinExamples(jvl, jvr);
- ret.add(joined);
- }
- }
- }
- break;
- }
-
- return ret;
- }
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java
index 6e206a657..d4ede4808 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java
@@ -16,10 +16,12 @@
package org.datavec.spark.transform.join;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import java.util.Collections;
+import java.util.Iterator;
import java.util.List;
/**
@@ -29,10 +31,43 @@ import java.util.List;
*
* @author Alex Black
*/
-public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee> {
+public class FilterAndFlattenJoinedValues implements FlatMapFunction> {
+
+ private final Join.JoinType joinType;
public FilterAndFlattenJoinedValues(Join.JoinType joinType) {
- super(new FilterAndFlattenJoinedValuesAdapter(joinType));
+ this.joinType = joinType;
+ }
+
+ @Override
+ public Iterator> call(JoinedValue joinedValue) throws Exception {
+ boolean keep;
+ switch (joinType) {
+ case Inner:
+ //Only keep joined values where we have both left and right
+ keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
+ break;
+ case LeftOuter:
+ //Keep all values where left is not missing/null
+ keep = joinedValue.isHaveLeft();
+ break;
+ case RightOuter:
+ //Keep all values where right is not missing/null
+ keep = joinedValue.isHaveRight();
+ break;
+ case FullOuter:
+ //Keep all values
+ keep = true;
+ break;
+ default:
+ throw new RuntimeException("Unknown/not implemented join type: " + joinType);
+ }
+
+ if (keep) {
+ return Collections.singletonList(joinedValue.getValues()).iterator();
+ } else {
+ return Collections.emptyIterator();
+ }
}
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java
deleted file mode 100644
index 3333276b1..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.join;
-
-import org.datavec.api.transform.join.Join;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Doing two things here:
- * (a) filter out any unnecessary values, and
- * (b) extract the List values from the JoinedValue
- *
- * @author Alex Black
- */
-public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter> {
-
- private final Join.JoinType joinType;
-
- public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) {
- this.joinType = joinType;
- }
-
- @Override
- public Iterable> call(JoinedValue joinedValue) throws Exception {
- boolean keep;
- switch (joinType) {
- case Inner:
- //Only keep joined values where we have both left and right
- keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
- break;
- case LeftOuter:
- //Keep all values where left is not missing/null
- keep = joinedValue.isHaveLeft();
- break;
- case RightOuter:
- //Keep all values where right is not missing/null
- keep = joinedValue.isHaveRight();
- break;
- case FullOuter:
- //Keep all values
- keep = true;
- break;
- default:
- throw new RuntimeException("Unknown/not implemented join type: " + joinType);
- }
-
- if (keep) {
- return Collections.singletonList(joinedValue.getValues());
- } else {
- return Collections.emptyList();
- }
- }
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java
index 28c91c84b..639e43836 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java
@@ -16,21 +16,69 @@
package org.datavec.spark.transform.sparkfunction;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
+import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import org.datavec.spark.transform.DataFrames;
-import java.util.List;
+import java.util.*;
/**
* Convert a record to a row
* @author Adam Gibson
*/
-public class SequenceToRows extends BaseFlatMapFunctionAdaptee>, Row> {
+public class SequenceToRows implements FlatMapFunction>, Row> {
+
+ private Schema schema;
+ private StructType structType;
public SequenceToRows(Schema schema) {
- super(new SequenceToRowsAdapter(schema));
+ this.schema = schema;
+ structType = DataFrames.fromSchemaSequence(schema);
}
+
+ @Override
+ public Iterator call(List> sequence) throws Exception {
+ if (sequence.size() == 0)
+ return Collections.emptyIterator();
+
+ String sequenceUUID = UUID.randomUUID().toString();
+
+ List out = new ArrayList<>(sequence.size());
+
+ int stepCount = 0;
+ for (List step : sequence) {
+ Object[] values = new Object[step.size() + 2];
+ values[0] = sequenceUUID;
+ values[1] = stepCount++;
+ for (int i = 0; i < step.size(); i++) {
+ switch (schema.getColumnTypes().get(i)) {
+ case Double:
+ values[i + 2] = step.get(i).toDouble();
+ break;
+ case Integer:
+ values[i + 2] = step.get(i).toInt();
+ break;
+ case Long:
+ values[i + 2] = step.get(i).toLong();
+ break;
+ case Float:
+ values[i + 2] = step.get(i).toFloat();
+ break;
+ default:
+ throw new IllegalStateException(
+ "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
+ }
+ }
+
+ Row row = new GenericRowWithSchema(values, structType);
+ out.add(row);
+ }
+
+ return out.iterator();
+ }
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java
deleted file mode 100644
index 2ca2f32ae..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.sparkfunction;
-
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
-import org.apache.spark.sql.types.StructType;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-import org.datavec.spark.transform.DataFrames;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.UUID;
-
-/**
- * Convert a record to a row
- * @author Adam Gibson
- */
-public class SequenceToRowsAdapter implements FlatMapFunctionAdapter>, Row> {
-
- private Schema schema;
- private StructType structType;
-
- public SequenceToRowsAdapter(Schema schema) {
- this.schema = schema;
- structType = DataFrames.fromSchemaSequence(schema);
- }
-
-
- @Override
- public Iterable call(List> sequence) throws Exception {
- if (sequence.size() == 0)
- return Collections.emptyList();
-
- String sequenceUUID = UUID.randomUUID().toString();
-
- List out = new ArrayList<>(sequence.size());
-
- int stepCount = 0;
- for (List step : sequence) {
- Object[] values = new Object[step.size() + 2];
- values[0] = sequenceUUID;
- values[1] = stepCount++;
- for (int i = 0; i < step.size(); i++) {
- switch (schema.getColumnTypes().get(i)) {
- case Double:
- values[i + 2] = step.get(i).toDouble();
- break;
- case Integer:
- values[i + 2] = step.get(i).toInt();
- break;
- case Long:
- values[i + 2] = step.get(i).toLong();
- break;
- case Float:
- values[i + 2] = step.get(i).toFloat();
- break;
- default:
- throw new IllegalStateException(
- "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
- }
- }
-
- Row row = new GenericRowWithSchema(values, structType);
- out.add(row);
- }
-
- return out;
- }
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java
index 41981a736..1a8782dfb 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java
@@ -16,19 +16,27 @@
package org.datavec.spark.transform.transform;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import java.util.Iterator;
import java.util.List;
/**
* Created by Alex on 17/03/2016.
*/
-public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee>, List>> {
+public class SequenceSplitFunction implements FlatMapFunction>, List>> {
+
+ private final SequenceSplit split;
public SequenceSplitFunction(SequenceSplit split) {
- super(new SequenceSplitFunctionAdapter(split));
+ this.split = split;
+ }
+
+ @Override
+ public Iterator>> call(List> collections) throws Exception {
+ return split.split(collections).iterator();
}
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java
deleted file mode 100644
index 5bde7ee62..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.transform;
-
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.List;
-
-/**
- * Created by Alex on 17/03/2016.
- */
-public class SequenceSplitFunctionAdapter
- implements FlatMapFunctionAdapter>, List>> {
-
- private final SequenceSplit split;
-
- public SequenceSplitFunctionAdapter(SequenceSplit split) {
- this.split = split;
- }
-
- @Override
- public Iterable>> call(List> collections) throws Exception {
- return split.split(collections);
- }
-}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java
index ffe3f80c6..81f07b1f4 100644
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java
+++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java
@@ -16,19 +16,32 @@
package org.datavec.spark.transform.transform;
+import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import java.util.Collections;
+import java.util.Iterator;
import java.util.List;
/**
* Spark function for executing a transform process
*/
-public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee, List> {
+public class SparkTransformProcessFunction implements FlatMapFunction, List> {
+
+ private final TransformProcess transformProcess;
public SparkTransformProcessFunction(TransformProcess transformProcess) {
- super(new SparkTransformProcessFunctionAdapter(transformProcess));
+ this.transformProcess = transformProcess;
+ }
+
+ @Override
+ public Iterator> call(List v1) throws Exception {
+ List newList = transformProcess.execute(v1);
+ if (newList == null)
+ return Collections.emptyIterator(); //Example was filtered out
+ else
+ return Collections.singletonList(newList).iterator();
}
}
diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java
deleted file mode 100644
index 7b1766cc2..000000000
--- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.transform;
-
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Spark function for executing a transform process
- */
-public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter, List> {
-
- private final TransformProcess transformProcess;
-
- public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) {
- this.transformProcess = transformProcess;
- }
-
- @Override
- public Iterable> call(List v1) throws Exception {
- List newList = transformProcess.execute(v1);
- if (newList == null)
- return Collections.emptyList(); //Example was filtered out
- else
- return Collections.singletonList(newList);
- }
-}
diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
deleted file mode 100644
index af600a14d..000000000
--- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-/**
- * FlatMapFunction adapter to
- * hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to FlatMapFunction
- *
- */
-public class BaseFlatMapFunctionAdaptee implements FlatMapFunction {
-
- protected final FlatMapFunctionAdapter adapter;
-
- public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) {
- this.adapter = adapter;
- }
-
- @Override
- public Iterable call(K k) throws Exception {
- return adapter.call(k);
- }
-}
diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java
deleted file mode 100644
index 0ad7c55bd..000000000
--- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.sql.DataFrame;
-
-/**
- * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to DataFrame / Dataset
- *
- */
-public class DataRowsFacade {
-
- private final DataFrame df;
-
- private DataRowsFacade(DataFrame df) {
- this.df = df;
- }
-
- public static DataRowsFacade dataRows(DataFrame df) {
- return new DataRowsFacade(df);
- }
-
- public DataFrame get() {
- return df;
- }
-}
diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
deleted file mode 100644
index f30e5a222..000000000
--- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Iterator;
-
-/**
- * FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to FlatMapFunction
- *
- */
-public class BaseFlatMapFunctionAdaptee implements FlatMapFunction {
-
- protected final FlatMapFunctionAdapter adapter;
-
- public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) {
- this.adapter = adapter;
- }
-
- @Override
- public Iterator call(K k) throws Exception {
- return adapter.call(k).iterator();
- }
-}
diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java
deleted file mode 100644
index 9958a622e..000000000
--- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-
-/**
- * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to DataFrame / Dataset
- *
- */
-public class DataRowsFacade {
-
- private final Dataset df;
-
- private DataRowsFacade(Dataset df) {
- this.df = df;
- }
-
- public static DataRowsFacade dataRows(Dataset df) {
- return new DataRowsFacade(df);
- }
-
- public Dataset get() {
- return df;
- }
-}
diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
index 49a3946e2..8f0247568 100644
--- a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
+++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
@@ -16,7 +16,7 @@
package org.datavec.spark.storage;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.*;
diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java
index 5b6ff6342..a19725a2a 100644
--- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java
+++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java
@@ -19,6 +19,8 @@ package org.datavec.spark.transform;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
@@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i));
Schema schema = builder.build();
- DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
- dataFrame.get().show();
- dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show();
+ Dataset dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
+ dataFrame.show();
+ dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show();
// System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames()));
// System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames()));
@@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
- DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
- dataFrame.get().show();
+ Dataset dataFrame = DataFrames.toDataFrame(schema, rdd);
+ dataFrame.show();
Column mean = DataFrames.mean(dataFrame, "0");
Column std = DataFrames.std(dataFrame, "0");
- dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show();
- dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show();
+ dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show();
+ dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show();
/* DataFrame desc = dataFrame.describe(dataFrame.columns());
dataFrame.show();
diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java
index 3b1b1e1a6..5352ec10d 100644
--- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java
+++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java
@@ -17,6 +17,7 @@
package org.datavec.spark.transform;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
@@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
@@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i));
+ Nd4j.getRandom().setSeed(12345);
+
+ INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns);
for (int i = 0; i < 5; i++) {
List record = new ArrayList<>(numColumns);
data.add(record);
for (int j = 0; j < numColumns; j++) {
- record.add(new DoubleWritable(1.0));
+ record.add(new DoubleWritable(arr.getDouble(i, j)));
}
-
}
- INDArray arr = RecordConverter.toMatrix(data);
Schema schema = builder.build();
JavaRDD> rdd = sc.parallelize(data);
- DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
+ Dataset dataFrame = DataFrames.toDataFrame(schema, rdd);
//assert equivalent to the ndarray pre processing
- NormalizerStandardize standardScaler = new NormalizerStandardize();
- standardScaler.fit(new DataSet(arr.dup(), arr.dup()));
- INDArray standardScalered = arr.dup();
- standardScaler.transform(new DataSet(standardScalered, standardScalered));
DataNormalization zeroToOne = new NormalizerMinMaxScaler();
zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
INDArray zeroToOnes = arr.dup();
zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));
- List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns());
+ List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
INDArray assertion = DataFrames.toMatrix(rows);
- //compare standard deviation
- assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1));
+ INDArray expStd = arr.std(true, true, 0);
+ INDArray std = assertion.getRow(0, true);
+ assertTrue(expStd.equalsWithEps(std, 1e-3));
//compare mean
- assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1));
+ INDArray expMean = arr.mean(true, 0);
+ assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3));
}
@@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
- DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
- dataFrame.get().show();
- Normalization.zeromeanUnitVariance(dataFrame).get().show();
- Normalization.normalize(dataFrame).get().show();
+ Dataset dataFrame = DataFrames.toDataFrame(schema, rdd);
+ dataFrame.show();
+ Normalization.zeromeanUnitVariance(dataFrame).show();
+ Normalization.normalize(dataFrame).show();
//assert equivalent to the ndarray pre processing
NormalizerStandardize standardScaler = new NormalizerStandardize();
diff --git a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
index 7b0d9d3c1..04773ee74 100644
--- a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
+++ b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java
@@ -52,18 +52,6 @@ public class DL4JSystemProperties {
*/
public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl";
- /**
- * Applicability: deeplearning4j-nn
- * Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an
- * alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in
- * comma-separated format.
- * This is required ONLY when ALL of the following conditions are met:
- * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND
- * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J), AND
- * 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}
- */
- public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses";
-
/**
* Applicability: deeplearning4j-nn
* Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled
diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml
index aeb5fe04b..81142fd68 100644
--- a/deeplearning4j/deeplearning4j-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-core/pom.xml
@@ -96,18 +96,6 @@
test
-
- com.google.guava
- guava
- ${guava.version}
-
-
- com.google.code.findbugs
- jsr305
-
-
-
-
org.nd4j
nd4j-api
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
index 56317b946..3bd1bd37f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.datasets.datavec;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java
index 6ff639cb4..1e82a4783 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java
@@ -17,7 +17,7 @@
package org.deeplearning4j.datasets.datavec;
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
index 935d83218..52d3b0774 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.nn.dtypes;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.reflect.ClassPath;
+import org.nd4j.shade.guava.collect.ImmutableSet;
+import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
@@ -103,7 +103,7 @@ public class DTypeTests extends BaseDL4JTest {
ImmutableSet info;
try {
//Dependency note: this ClassPath class was added in Guava 14
- info = com.google.common.reflect.ClassPath.from(DTypeTests.class.getClassLoader())
+ info = org.nd4j.shade.guava.reflect.ClassPath.from(DTypeTests.class.getClassLoader())
.getTopLevelClassesRecursive("org.deeplearning4j");
} catch (IOException e) {
//Should never happen
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
index fb5836a99..11e45c51d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java
@@ -229,7 +229,9 @@ public class TestRnnLayers extends BaseDL4JTest {
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
- assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
+ if(msg == null)
+ t.printStackTrace();
+ assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}
}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java
index 77fc63aa0..aba60fe4f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.plot;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
index f112d4386..a66914cd7 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java
@@ -22,23 +22,24 @@ import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ConvolutionMode;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
-import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.BatchNormalization;
+import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
+import org.deeplearning4j.nn.conf.layers.GravesLSTM;
+import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitXavier;
-import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
-import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
-import org.nd4j.linalg.activations.impl.*;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.regularization.WeightDecay;
@@ -60,6 +61,9 @@ public class RegressionTest100a extends BaseDL4JTest {
@Test
public void testCustomLayer() throws Exception {
+ //We dropped support for 1.0.0-alpha and earlier custom layers due to the maintenance overhead for a rarely used feature
+ //An upgrade path exists as a workaround - load in beta to beta4 and re-save
+ //All built-in layers can be loaded going back to 0.5.0
File f = Resources.asFile("regression_testing/100a/CustomLayerExample_100a.bin");
@@ -68,67 +72,8 @@ public class RegressionTest100a extends BaseDL4JTest {
fail("Expected exception");
} catch (Exception e){
String msg = e.getMessage();
- assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
+ assertTrue(msg, msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again"));
}
-
- NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
-
- MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
-
- DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
- assertEquals(new ActivationTanH(), l0.getActivationFn());
- assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
- assertEquals(new RmsProp(0.95), l0.getIUpdater());
-
- CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
- assertEquals(new ActivationTanH(), l1.getActivationFn());
- assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
- assertEquals(new RmsProp(0.95), l1.getIUpdater());
-
-
- INDArray outExp;
- File f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Output_100a.bin");
- try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
- outExp = Nd4j.read(dis);
- }
-
- INDArray in;
- File f3 = Resources.asFile("regression_testing/100a/CustomLayerExample_Input_100a.bin");
- try(DataInputStream dis = new DataInputStream(new FileInputStream(f3))){
- in = Nd4j.read(dis);
- }
-
- INDArray outAct = net.output(in);
-
- assertEquals(outExp, outAct);
-
-
- //Check graph
- f = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_100a.bin");
-
- //Deregister custom class:
- new LegacyLayerDeserializer().getLegacyNamesMap().remove("CustomLayer");
-
- try {
- ComputationGraph.load(f, true);
- fail("Expected exception");
- } catch (Exception e){
- String msg = e.getMessage();
- assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
- }
-
- NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
-
- ComputationGraph graph = ComputationGraph.load(f, true);
-
- f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_Output_100a.bin");
- try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
- outExp = Nd4j.read(dis);
- }
-
- outAct = graph.outputSingle(in);
-
- assertEquals(outExp, outAct);
}
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java
index e7fb7246a..592aefcd6 100755
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.datasets.iterator;
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.annotations.VisibleForTesting;
+import org.nd4j.shade.guava.collect.Lists;
import lombok.Getter;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java
index 955732c07..02f2c4eb0 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.datasets.iterator.parallel;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java
index dd83c1bd4..7eca0fac0 100644
--- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java
+++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java
@@ -17,7 +17,7 @@
package org.deeplearning4j.plot;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Setter;
diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java
index 41b50795e..9efb88e24 100644
--- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java
+++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.plot;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Ints;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml
index 59f05a6e8..dec29266f 100644
--- a/deeplearning4j/deeplearning4j-modelimport/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml
@@ -37,6 +37,12 @@
${nd4j.version}
+
+ com.google.code.gson
+ gson
+ ${gson.version}
+
+
org.deeplearning4j
deeplearning4j-nn
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
index 5351f0955..7477c7794 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
@@ -77,26 +77,11 @@
${project.version}
-
- com.google.guava
- guava
- ${guava.version}
-
com.google.protobuf
protobuf-java
${google.protobuf.version}
-
- com.typesafe.akka
- akka-actor_2.11
- ${akka.version}
-
-
- com.typesafe.akka
- akka-slf4j_2.11
- ${akka.version}
-
joda-time
joda-time
@@ -213,11 +198,6 @@
play-netty-server_2.11
${playframework.version}
-
- com.typesafe.akka
- akka-cluster_2.11
- ${akka.version}
-
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java
deleted file mode 100644
index df178fd70..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.nearestneighbor.server;
-
-import play.libs.F;
-import play.mvc.Result;
-
-import java.util.function.Function;
-import java.util.function.Supplier;
-
-/**
- * Utility methods for Routing
- *
- * @author Alex Black
- */
-public class FunctionUtil {
-
-
- public static F.Function0 function0(Supplier supplier) {
- return supplier::get;
- }
-
- public static F.Function function(Function function) {
- return function::apply;
- }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
index 58682d5e1..a79b57b19 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java
@@ -33,8 +33,10 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
+import play.BuiltInComponents;
import play.Mode;
import play.libs.Json;
+import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
@@ -149,19 +151,36 @@ public class NearestNeighborsServer {
VPTree tree = new VPTree(points, similarityFunction, invert);
- RoutingDsl routingDsl = new RoutingDsl();
+ //Set play secret key, if required
+ //http://www.playframework.com/documentation/latest/ApplicationSecret
+ String crypto = System.getProperty("play.crypto.secret");
+ if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
+ byte[] newCrypto = new byte[1024];
+
+ new Random().nextBytes(newCrypto);
+
+ String base64 = Base64.getEncoder().encodeToString(newCrypto);
+ System.setProperty("play.crypto.secret", base64);
+ }
+
+
+ server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
+ }
+
+ protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){
+ RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
//return the host information for a given id
- routingDsl.POST("/knn").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/knn").routingTo(request -> {
try {
- NearestNeighborRequest record = Json.fromJson(request().body().asJson(), NearestNeighborRequest.class);
+ NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
NearestNeighbor nearestNeighbor =
- NearestNeighbor.builder().points(points).record(record).tree(tree).build();
+ NearestNeighbor.builder().points(points).record(record).tree(tree).build();
if (record == null)
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
NearestNeighborsResults results =
- NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
+ NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
return ok(Json.toJson(results));
@@ -171,11 +190,11 @@ public class NearestNeighborsServer {
e.printStackTrace();
return internalServerError(e.getMessage());
}
- })));
+ });
- routingDsl.POST("/knnnew").routeTo(FunctionUtil.function0((() -> {
+ routingDsl.POST("/knnnew").routingTo(request -> {
try {
- Base64NDArrayBody record = Json.fromJson(request().body().asJson(), Base64NDArrayBody.class);
+ Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
if (record == null)
return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
@@ -216,23 +235,9 @@ public class NearestNeighborsServer {
e.printStackTrace();
return internalServerError(e.getMessage());
}
- })));
-
- //Set play secret key, if required
- //http://www.playframework.com/documentation/latest/ApplicationSecret
- String crypto = System.getProperty("play.crypto.secret");
- if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
- byte[] newCrypto = new byte[1024];
-
- new Random().nextBytes(newCrypto);
-
- String base64 = Base64.getEncoder().encodeToString(newCrypto);
- System.setProperty("play.crypto.secret", base64);
- }
-
- server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
-
+ });
+ return routingDsl.build();
}
/**
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
index 8b774cd56..f95f9268d 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
@@ -59,6 +59,13 @@
${project.version}
test
+
+
+ joda-time
+ joda-time
+ 2.10.3
+ test
+
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
index 1c57bc38a..cae103f10 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.clustering.info;
-import com.google.common.collect.HashBasedTable;
-import com.google.common.collect.Table;
+import org.nd4j.shade.guava.collect.HashBasedTable;
+import org.nd4j.shade.guava.collect.Table;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java
index c2154e6ba..f1cc2e304 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.clustering.quadtree;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java
index 11746f4c2..5c31ee78a 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.clustering.randomprojection;
-import com.google.common.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java
index 83af0365a..659f334df 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.clustering.sptree;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.val;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.nn.conf.WorkspaceMode;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java
index 4b7b8e567..1de7a379b 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.clustering.kdtree;
-import com.google.common.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.deeplearning4j.clustering.BaseDL4JTest;
import org.joda.time.Duration;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java
index 03ad90748..f5ee19403 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.clustering.sptree;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.apache.commons.lang3.time.StopWatch;
import org.deeplearning4j.clustering.BaseDL4JTest;
import org.junit.Before;
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java
index 76658438a..5edb3926a 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java
@@ -18,12 +18,10 @@ package org.deeplearning4j.clustering.vptree;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.BaseDL4JTest;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.joda.time.Duration;
import org.junit.BeforeClass;
-import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -33,7 +31,6 @@ import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.primitives.Pair;
import java.util.*;
-import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java
index 67169132c..03d2462a5 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.text.corpora.sentiwordnet;
-import com.google.common.collect.Sets;
+import org.nd4j.shade.guava.collect.Sets;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.CASException;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
index d2d752509..f4dd1a6c5 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.models;
-import com.google.common.io.Files;
-import com.google.common.primitives.Doubles;
+import org.nd4j.shade.guava.io.Files;
+import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
index 61e31b3c7..01b38a644 100755
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.models.word2vec;
-import com.google.common.primitives.Doubles;
-import com.google.common.primitives.Ints;
+import org.nd4j.shade.guava.primitives.Doubles;
+import org.nd4j.shade.guava.primitives.Ints;
import lombok.val;
import net.didion.jwnl.data.Word;
import org.apache.commons.io.FileUtils;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java
index 5be56964c..579caa0a3 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.embeddings.inmemory;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java
index df502aded..84fc17b7e 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.embeddings.reader.impl;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java
index 75511cae1..f71c56717 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.embeddings.wordvectors;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java
index f97aee9c0..8680a809c 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.glove.count;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.primitives.Pair;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
index 4c05b7cc7..64ee79dd4 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.paragraphvectors;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import lombok.Getter;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
index 78a878930..87dd0880a 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.models.sequencevectors;
-import com.google.common.primitives.Ints;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.primitives.Ints;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java
index d5e789ebf..99263f6bc 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.sequencevectors.sequence;
-import com.google.common.util.concurrent.AtomicDouble;
+import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java
index cdb4b6c9e..5e51cef20 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.text.invertedindex;
-import com.google.common.base.Function;
+import org.nd4j.shade.guava.base.Function;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.primitives.Pair;
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java
index 41536ff70..f5e0ca388 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.models.embeddings.wordvectors;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java
index 5834e1647..69767df13 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java
@@ -16,8 +16,8 @@
package org.deeplearning4j.eval;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
+import org.nd4j.shade.guava.collect.HashMultiset;
+import org.nd4j.shade.guava.collect.Multiset;
import lombok.Getter;
import java.io.Serializable;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java
index e964cd9b0..2b00ac375 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.eval.curves;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java
index b66230ddd..5d5e65c2a 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java
@@ -16,7 +16,7 @@
package org.deeplearning4j.eval.curves;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.shade.jackson.annotation.JsonProperty;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
index 6cd8f06b3..5a5ce5665 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.OutputLayerUtil;
@@ -40,6 +41,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -172,6 +174,26 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
ComputationGraphConfiguration conf;
try {
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ try{
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class);
+ } catch (InvalidTypeIdException e2){
+ //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
+ //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
+ String msg = e2.getMessage();
+ if(msg != null && msg.contains("Could not resolve type id")){
+ throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " +
+ "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" +
+ " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
+ }
+ throw new RuntimeException(e2);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (Exception e) {
//Check if this exception came from legacy deserializer...
String msg = e.getMessage();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
index ab7fd044b..4f02e3d66 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java
@@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializerHelper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -34,8 +33,7 @@ import java.io.Serializable;
*
* @author Adam Gibson
*/
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyPreprocessorDeserializerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface InputPreProcessor extends Serializable, Cloneable {
/**
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
index de3373323..52a67ef16 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java
@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
+import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
@@ -41,6 +42,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import java.io.IOException;
@@ -157,6 +159,26 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
ObjectMapper mapper = NeuralNetConfiguration.mapper();
try {
conf = mapper.readValue(json, MultiLayerConfiguration.class);
+ } catch (InvalidTypeIdException e){
+ if(e.getMessage().contains("@class")){
+ try {
+ //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+ return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class);
+ } catch (InvalidTypeIdException e2){
+ //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
+ //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
+ String msg = e2.getMessage();
+ if(msg != null && msg.contains("Could not resolve type id")){
+ throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " +
+ "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" +
+ " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e);
+ }
+ throw new RuntimeException(e2);
+ } catch (IOException e2){
+ throw new RuntimeException(e2);
+ }
+ }
+ throw new RuntimeException(e);
} catch (IOException e) {
//Check if this exception came from legacy deserializer...
String msg = e.getMessage();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
index d8e77a55a..0da5bea13 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java
@@ -26,20 +26,14 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
-import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
-import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializer;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializer;
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.nn.weights.IWeightInit;
@@ -59,10 +53,6 @@ import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
-import org.nd4j.linalg.primitives.Pair;
-import org.nd4j.serde.json.LegacyIActivationDeserializer;
-import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.IOException;
@@ -342,9 +332,7 @@ public class NeuralNetConfiguration implements Serializable, Cloneable {
ObjectMapper mapper = mapper();
try {
- String ret = mapper.writeValueAsString(this);
- return ret;
-
+ return mapper.writeValueAsString(this);
} catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
throw new RuntimeException(e);
}
@@ -384,86 +372,6 @@ public class NeuralNetConfiguration implements Serializable, Cloneable {
return JsonMappers.getMapper();
}
- /**
- * Set of classes that can be registered for legacy deserialization.
- */
- private static List> REGISTERABLE_CUSTOM_CLASSES = (List>) Arrays.>asList(
- Layer.class,
- GraphVertex.class,
- InputPreProcessor.class,
- IActivation.class,
- ILossFunction.class,
- ReconstructionDistribution.class
- );
-
- /**
- * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
- * ONLY) for JSON deserialization.
- *
- * This is required ONLY when BOTH of the following conditions are met:
- * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND
- * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J)
- *
- * By passing the classes of these layers here, DL4J should be able to deserialize them, in spite of the JSON
- * format change between versions.
- *
- * @param classes Classes to register
- */
- public static void registerLegacyCustomClassesForJSON(Class>... classes) {
- registerLegacyCustomClassesForJSONList(Arrays.>asList(classes));
- }
-
- /**
- * @see #registerLegacyCustomClassesForJSON(Class[])
- */
- public static void registerLegacyCustomClassesForJSONList(List> classes){
- //Default names (i.e., old format for custom JSON format)
- List> list = new ArrayList<>();
- for(Class> c : classes){
- list.add(new Pair(c.getSimpleName(), c));
- }
- registerLegacyCustomClassesForJSON(list);
- }
-
- /**
- * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
- * ONLY) for JSON deserialization, with custom names.
- * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
- * but is added in case it is required in non-standard circumstances.
- */
- public static void registerLegacyCustomClassesForJSON(List> classes){
- for(Pair p : classes){
- String s = p.getFirst();
- Class c = p.getRight();
- //Check if it's a valid class to register...
- boolean found = false;
- for( Class> c2 : REGISTERABLE_CUSTOM_CLASSES){
- if(c2.isAssignableFrom(c)){
- if(c2 == Layer.class){
- LegacyLayerDeserializer.registerLegacyClassSpecifiedName(s, (Class extends Layer>)c);
- } else if(c2 == GraphVertex.class){
- LegacyGraphVertexDeserializer.registerLegacyClassSpecifiedName(s, (Class extends GraphVertex>)c);
- } else if(c2 == InputPreProcessor.class){
- LegacyPreprocessorDeserializer.registerLegacyClassSpecifiedName(s, (Class extends InputPreProcessor>)c);
- } else if(c2 == IActivation.class ){
- LegacyIActivationDeserializer.registerLegacyClassSpecifiedName(s, (Class extends IActivation>)c);
- } else if(c2 == ILossFunction.class ){
- LegacyILossFunctionDeserializer.registerLegacyClassSpecifiedName(s, (Class extends ILossFunction>)c);
- } else if(c2 == ReconstructionDistribution.class){
- LegacyReconstructionDistributionDeserializer.registerLegacyClassSpecifiedName(s, (Class extends ReconstructionDistribution>)c);
- }
-
- found = true;
- }
- }
-
- if(!found){
- throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
- c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
- }
- }
- }
-
/**
* NeuralNetConfiguration builder, used as a starting point for creating a MultiLayerConfiguration or
* ComputationGraphConfiguration.
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
index 2cd8d83f4..6ca3b35de 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
@@ -15,7 +15,7 @@
******************************************************************************/
package org.deeplearning4j.nn.conf.graph;
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
import lombok.*;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.inputs.InputType;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java
index e4968a49e..497cc77f5 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java
@@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf.graph;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializerHelper;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -33,8 +32,7 @@ import java.io.Serializable;
*
* @author Alex Black
*/
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyGraphVertexDeserializerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public abstract class GraphVertex implements Cloneable, Serializable {
@Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java
index 805e1729c..85da86fa2 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java
@@ -16,12 +16,15 @@
package org.deeplearning4j.nn.conf.inputs;
-import lombok.*;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonInclude;
-import org.nd4j.shade.jackson.annotation.JsonSubTypes;
+import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable;
@@ -36,12 +39,7 @@ import java.util.Arrays;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
-@JsonSubTypes(value = {@JsonSubTypes.Type(value = InputType.InputTypeFeedForward.class, name = "FeedForward"),
- @JsonSubTypes.Type(value = InputType.InputTypeRecurrent.class, name = "Recurrent"),
- @JsonSubTypes.Type(value = InputType.InputTypeConvolutional.class, name = "Convolutional"),
- @JsonSubTypes.Type(value = InputType.InputTypeConvolutionalFlat.class, name = "ConvolutionalFlat"),
- @JsonSubTypes.Type(value = InputType.InputTypeConvolutional3D.class, name = "Convolutional3D")})
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public abstract class InputType implements Serializable {
/**
@@ -174,13 +172,16 @@ public abstract class InputType implements Serializable {
}
- @AllArgsConstructor
- @Getter
@NoArgsConstructor
+ @Getter
@EqualsAndHashCode(callSuper = false)
public static class InputTypeFeedForward extends InputType {
private long size;
+ public InputTypeFeedForward(@JsonProperty("size") long size) {
+ this.size = size;
+ }
+
@Override
public Type getType() {
return Type.FF;
@@ -203,9 +204,8 @@ public abstract class InputType implements Serializable {
}
}
- @Getter
@NoArgsConstructor
- @AllArgsConstructor
+ @Getter
@EqualsAndHashCode(callSuper = false)
public static class InputTypeRecurrent extends InputType {
private long size;
@@ -215,6 +215,11 @@ public abstract class InputType implements Serializable {
this(size, -1);
}
+ public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) {
+ this.size = size;
+ this.timeSeriesLength = timeSeriesLength;
+ }
+
@Override
public Type getType() {
return Type.RNN;
@@ -245,15 +250,19 @@ public abstract class InputType implements Serializable {
}
}
- @AllArgsConstructor
+ @NoArgsConstructor
@Data
@EqualsAndHashCode(callSuper = false)
- @NoArgsConstructor
public static class InputTypeConvolutional extends InputType {
private long height;
private long width;
private long channels;
+ public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) {
+ this.height = height;
+ this.width = width;
+ this.channels = channels;
+ }
/**
* Return the number of channels / depth for this 2D convolution. This method has been deprecated,
@@ -298,10 +307,9 @@ public abstract class InputType implements Serializable {
}
}
- @AllArgsConstructor
+ @NoArgsConstructor
@Data
@EqualsAndHashCode(callSuper = false)
- @NoArgsConstructor
public static class InputTypeConvolutional3D extends InputType {
private Convolution3D.DataFormat dataFormat;
private long depth;
@@ -309,6 +317,15 @@ public abstract class InputType implements Serializable {
private long width;
private long channels;
+ public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat,
+ @JsonProperty("depth") long depth, @JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) {
+ this.dataFormat = dataFormat;
+ this.depth = depth;
+ this.height = height;
+ this.width = width;
+ this.channels = channels;
+ }
+
@Override
public Type getType() {
return Type.CNN3D;
@@ -336,15 +353,20 @@ public abstract class InputType implements Serializable {
}
}
- @AllArgsConstructor
+ @NoArgsConstructor
@Data
@EqualsAndHashCode(callSuper = false)
- @NoArgsConstructor
public static class InputTypeConvolutionalFlat extends InputType {
private long height;
private long width;
private long depth;
+ public InputTypeConvolutionalFlat(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("depth") long depth) {
+ this.height = height;
+ this.width = width;
+ this.depth = depth;
+ }
+
@Override
public Type getType() {
return Type.CNNFlat;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
index 5dfb3b671..25577bd1f 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java
@@ -29,7 +29,6 @@ import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializerHelper;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -45,8 +44,7 @@ import java.util.*;
* A neural network layer.
*/
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyLayerDeserializerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
@NoArgsConstructor
public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
index 1e5eb11f2..0f1a770a8 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java
@@ -22,7 +22,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyIntArrayDeserializer;
+import org.deeplearning4j.nn.conf.serde.legacy.LegacyIntArrayDeserializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.linalg.api.buffer.DataType;
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
index 7b2317a8f..f72da09e5 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java
@@ -27,7 +27,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
-import org.deeplearning4j.nn.conf.serde.FrozenLayerDeserializer;
import org.deeplearning4j.nn.params.FrozenLayerParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
@@ -48,7 +47,6 @@ import java.util.List;
* @author Alex Black
*/
@EqualsAndHashCode(callSuper = false)
-@JsonDeserialize(using = FrozenLayerDeserializer.class)
public class FrozenLayer extends Layer {
@Getter
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java
index 0de9143fe..f58e6b0e7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java
@@ -16,7 +16,6 @@
package org.deeplearning4j.nn.conf.layers.variational;
-import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializerHelper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.io.Serializable;
*
* @author Alex Black
*/
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
- defaultImpl = LegacyReconstructionDistributionDeserializerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ReconstructionDistribution extends Serializable {
/**
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
index 5194da4a1..d32488363 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java
@@ -21,12 +21,18 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
+import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.*;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
+import org.nd4j.linalg.lossfunctions.impl.*;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
@@ -38,6 +44,8 @@ import org.nd4j.shade.jackson.databind.node.ObjectNode;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
/**
* A custom (abstract) deserializer that handles backward compatibility (currently only for updater refactoring that
@@ -103,6 +111,24 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im
return false;
}
+ protected boolean requiresActivationFromLegacy(Layer[] layers){
+ for(Layer l : layers){
+ if(l instanceof BaseLayer && ((BaseLayer)l).getActivationFn() == null){
+ return true;
+ }
+ }
+ return false;
+ }
+
+ protected boolean requiresLegacyLossHandling(Layer[] layers){
+ for(Layer l : layers){
+ if(l instanceof BaseOutputLayer && ((BaseOutputLayer)l).getLossFn() == null){
+ return true;
+ }
+ }
+ return false;
+ }
+
protected void handleUpdaterBackwardCompatibility(BaseLayer layer, ObjectNode on){
if(on != null && on.has("updater")){
String updaterName = on.get("updater").asText();
@@ -220,7 +246,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im
}
protected void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){
- if(on != null && (on.has("weightInit") )){
+ if(on != null && on.has("weightInit") ){
//Legacy format JSON
if(on.has("weightInit")){
String wi = on.get("weightInit").asText();
@@ -228,8 +254,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im
WeightInit w = WeightInit.valueOf(wi);
Distribution d = null;
if(w == WeightInit.DISTRIBUTION && on.has("dist")){
- //TODO deserialize distribution
- String dist = on.get("dist").asText();
+ String dist = on.get("dist").toString();
d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
}
IWeightInit iwi = w.getWeightInitFunction(d);
@@ -241,6 +266,57 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer