Version upgrades (#199)

Commits in this PR (Signed-off-by: AlexDBlack <blacka101@gmail.com>, except those marked [AS]: Alexander Stoyakin <alexander.stoyakin@gmail.com>):

* DataVec fixes for Jackson version upgrade
* DL4J jackson updates + databind version 2.9.9.3
* Shade snakeyaml along with jackson
* Version fix
* Switch DataVec legacy JSON format handling to mixins
* Next set of fixes
* Cleanup for legacy JSON mapping
* Upgrade commons compress to 1.18; small test fix
* New Jackson backward compatibility for DL4J - Round 1
* New Jackson backward compatibility for DL4J - Round 2
* More fixes, all but legacy custom passing
* Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format
* Legacy deserialization cleanup
* Small amount of polish - legacy JSON
* Upgrade guava version
* IEvaluation legacy format deserialization fix
* Upgrade play version to 2.7.3
* Update nd4j-parameter-server-status to new Play API
* Update DL4J UI for new play version
* More play framework updates
* Small fixes
* Remove Spark 1/2 adapter code from DataVec
* datavec-spark dependency cleanup
* DL4J spark updates, pt 1
* DL4J spark updates, pt 2
* DL4J spark updates, pt 3
* DL4J spark updates, pt 4
* Test fix
* Another fix
* Breeze upgrade, dependency cleanup
* Add Scala 2.12 version to pom.xml
* change-scala-versions.sh - add scala 2.12, remove 2.10
* Move Spark version properties to parent pom (now that only one spark version is supported)
* DataVec Play fixes
* datavec play dependency fixes
* Clean up old spark/jackson stuff
* Cleanup jackson unused dependencies
* Add shaded guava
* Dropping redundant dependency [AS]
* Removed scalaxy dependency [AS]
* Ensure not possible to import pre-shaded classes, and remove direct guava dependencies in favor of shaded
* ND4J Shaded guava import fixes
* DataVec and DL4J guava shading
* Arbiter, RL4J fixes
* Build fixed [AS]
* Fix dependency [AS]
* Fix bad merge
* Jackson shading fixes
* Set play secret, datavec-spark-inference-server
* Fix for datavec-spark-inference-server
* Arbiter fixes
* Arbiter fixes
* Small test fix
parent 378669cc10
commit dcc2baa676
@@ -72,6 +72,11 @@
             <scope>test</scope>
         </dependency>

+        <dependency>
+            <groupId>joda-time</groupId>
+            <artifactId>joda-time</artifactId>
+            <version>${jodatime.version}</version>
+        </dependency>

         <!-- ND4J Shaded Jackson Dependency -->
         <dependency>
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.distribution;

-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.Getter;
 import org.apache.commons.math3.distribution.RealDistribution;
 import org.apache.commons.math3.exception.NumberIsTooLargeException;
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.optimize.runner;

-import com.google.common.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
 import lombok.AllArgsConstructor;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
@@ -16,9 +16,9 @@

 package org.deeplearning4j.arbiter.optimize.runner;

-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
+import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
+import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
+import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
 import lombok.Setter;
 import org.deeplearning4j.arbiter.optimize.api.*;
 import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
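Across Arbiter, DataVec and DL4J the pattern in these hunks is the same: direct com.google.common imports are swapped for the relocated org.nd4j.shade.guava packages. A minimal, hedged illustration of what calling code looks like after the switch (the class and the checks are hypothetical; only the shaded package names are taken from the diff):

    import org.nd4j.shade.guava.base.Preconditions;
    import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
    import org.nd4j.shade.guava.util.concurrent.MoreExecutors;

    import java.util.concurrent.Executors;

    public class ShadedGuavaExample {
        public static void main(String[] args) {
            // Same Guava APIs as before, just under the relocated package prefix
            Preconditions.checkArgument(args != null, "args must not be null");

            ListeningExecutorService exec =
                    MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(2));
            exec.shutdown();
        }
    }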
@@ -43,13 +43,15 @@ public class JsonMapper {
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
         mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
         yamlMapper = new ObjectMapper(new YAMLFactory());
-        mapper.registerModule(new JodaModule());
-        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        mapper.enable(SerializationFeature.INDENT_OUTPUT);
-        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
-        mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        yamlMapper.registerModule(new JodaModule());
+        yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+        yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+        yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
+        yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+        yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
     }

     private JsonMapper() {}
@@ -39,6 +39,7 @@ public class YamlMapper {
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
         mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
     }


@@ -59,6 +59,7 @@ public class TestJson {
         om.enable(SerializationFeature.INDENT_OUTPUT);
         om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
         return om;
     }

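These three hunks converge on the same Jackson visibility setup, and the first one also fixes the YAML mapper previously being configured through the JSON mapper's reference. As a standalone sketch under the assumption that the ND4J-shaded Jackson packages are on the classpath, the resulting configuration is roughly:

    import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
    import org.nd4j.shade.jackson.annotation.PropertyAccessor;
    import org.nd4j.shade.jackson.databind.DeserializationFeature;
    import org.nd4j.shade.jackson.databind.ObjectMapper;
    import org.nd4j.shade.jackson.databind.SerializationFeature;
    import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
    import org.nd4j.shade.jackson.datatype.joda.JodaModule;

    public class MapperConfigSketch {
        static ObjectMapper configure(ObjectMapper mapper) {
            mapper.registerModule(new JodaModule());
            mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
            mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
            mapper.enable(SerializationFeature.INDENT_OUTPUT);
            // Map via fields and constructors only; ignore getters/setters
            mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
            mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
            mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
            return mapper;
        }

        public static void main(String[] args) {
            ObjectMapper json = configure(new ObjectMapper());
            ObjectMapper yaml = configure(new ObjectMapper(new YAMLFactory()));
        }
    }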
@@ -57,6 +57,12 @@
             <artifactId>jackson</artifactId>
             <version>${nd4j.version}</version>
         </dependency>
+
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
     </dependencies>

     <profiles>
@@ -16,7 +16,7 @@

 package org.deeplearning4j.arbiter.layers;

-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.AccessLevel;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
@@ -107,12 +107,6 @@
     </profiles>

     <dependencies>
-        <dependency>
-            <groupId>com.google.guava</groupId>
-            <artifactId>guava</artifactId>
-            <version>${guava.version}</version>
-        </dependency>
-
         <dependency>
             <groupId>org.deeplearning4j</groupId>
             <artifactId>arbiter-core</artifactId>
@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.ui.misc;


 import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
+import org.nd4j.shade.jackson.annotation.PropertyAccessor;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.MapperFeature;
@@ -45,12 +46,9 @@ public class JsonMapper {
         mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
         mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
         mapper.enable(SerializationFeature.INDENT_OUTPUT);
-        mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker()
-                        .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
-                        .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
-                        .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
-                        .withCreatorVisibility(JsonAutoDetect.Visibility.NONE));
+        mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
+        mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);

         return mapper;
     }
@@ -20,9 +20,9 @@

 set -e

-VALID_VERSIONS=( 2.10 2.11 )
-SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}";
+VALID_VERSIONS=( 2.11 2.12 )
 SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}";
+SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}";

 usage() {
   echo "Usage: $(basename $0) [-h|--help] <scala version to be used>
@@ -45,19 +45,18 @@ check_scala_version() {
   exit 1
 }

-
 check_scala_version "$TO_VERSION"

 if [ $TO_VERSION = "2.11" ]; then
-  FROM_BINARY="_2\.10"
+  FROM_BINARY="_2\.12"
   TO_BINARY="_2\.11"
-  FROM_VERSION=$SCALA_210_VERSION
+  FROM_VERSION=$SCALA_212_VERSION
   TO_VERSION=$SCALA_211_VERSION
 else
   FROM_BINARY="_2\.11"
-  TO_BINARY="_2\.10"
+  TO_BINARY="_2\.12"
   FROM_VERSION=$SCALA_211_VERSION
-  TO_VERSION=$SCALA_210_VERSION
+  TO_VERSION=$SCALA_212_VERSION
 fi

 sed_i() {
@@ -70,35 +69,24 @@ echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION t

 BASEDIR=$(dirname $0)

-#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc.
+#Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc.
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \;

-#Scala versions, like <scala.version>2.10</scala.version>
+#Scala versions, like <scala.version>2.11</scala.version>
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \;

-#Scala binary versions, like <scala.binary.version>2.10</scala.binary.version>
+#Scala binary versions, like <scala.binary.version>2.11</scala.binary.version>
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \;

-#Scala versions, like <artifactId>scala-library</artifactId> <version>2.10.6</version>
+#Scala versions, like <artifactId>scala-library</artifactId> <version>2.11.12</version>
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \;

-#Scala maven plugin, <scalaVersion>2.10</scalaVersion>
+#Scala maven plugin, <scalaVersion>2.11</scalaVersion>
 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
     -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \;

-#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306
-#https://github.com/deeplearning4j/deeplearning4j/issues/6306
-if [[ $TO_VERSION == 2.11* ]]; then
-    sed_i 's/<artifactId>korean-text-scala-2.10<\/artifactId>/<artifactId>korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-    sed_i 's/<version>4.2.0<\/version>/<version>4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-else
-    sed_i 's/<artifactId>korean-text<\/artifactId>/<artifactId>korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-    sed_i 's/<version>4.4<\/version>/<version>4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
-fi

 echo "Done updating Scala versions.";
@@ -1,85 +0,0 @@
-#!/usr/bin/env bash
-
-################################################################################
-# Copyright (c) 2015-2018 Skymind, Inc.
-#
-# This program and the accompanying materials are made available under the
-# terms of the Apache License, Version 2.0 which is available at
-# https://www.apache.org/licenses/LICENSE-2.0.
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-################################################################################
-
-# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications.
-
-set -e
-
-VALID_VERSIONS=( 1 2 )
-SPARK_2_VERSION="2\.1\.0"
-SPARK_1_VERSION="1\.6\.3"
-
-usage() {
-  echo "Usage: $(basename $0) [-h|--help] <spark version to be used>
-  where :
-    -h| --help Display this help text
-    valid spark version values : ${VALID_VERSIONS[*]}
-  " 1>&2
-  exit 1
-}
-
-if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
-  usage
-fi
-
-TO_VERSION=$1
-
-check_spark_version() {
-  for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
-  echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
-  exit 1
-}
-
-
-check_spark_version "$TO_VERSION"
-
-if [ $TO_VERSION = "2" ]; then
-  FROM_BINARY="1"
-  TO_BINARY="2"
-  FROM_VERSION=$SPARK_1_VERSION
-  TO_VERSION=$SPARK_2_VERSION
-else
-  FROM_BINARY="2"
-  TO_BINARY="1"
-  FROM_VERSION=$SPARK_2_VERSION
-  TO_VERSION=$SPARK_1_VERSION
-fi
-
-sed_i() {
-  sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
-}
-
-export -f sed_i
-
-echo "Updating Spark versions in pom.xml files to Spark $1";
-
-BASEDIR=$(dirname $0)
-
-# <spark.major.version>1</spark.major.version>
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \;
-
-# <spark.version>1.6.3</spark.version>
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \;
-
-#Spark versions, like <version>xxx_spark_2xxx</version> OR <datavec.spark.version>xxx_spark_2xxx</datavec.spark.version>
-find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-    -exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \;
-
-echo "Done updating Spark versions.";
@@ -16,7 +16,6 @@

 package org.datavec.api.transform;

-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -27,8 +26,7 @@ import java.util.List;
 /**A Transform converts an example to another example, or a sequence to another sequence
  */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.TransformHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Transform extends Serializable, ColumnOp {

     /**
@@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

 import java.io.IOException;
 import java.io.Serializable;
@@ -417,6 +418,16 @@ public class TransformProcess implements Serializable {
     public static TransformProcess fromJson(String json) {
         try {
             return JsonMappers.getMapper().readValue(json, TransformProcess.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                try{
+                    return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (IOException e) {
             //TODO proper exception message
             throw new RuntimeException(e);
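The same fallback shape recurs in DataAnalysis, SequenceDataAnalysis and Schema below: parse with the current "@class"-based mapper first, and only if that fails with an InvalidTypeIdException mentioning "@class" retry with the legacy (pre-1.0.0-beta) mapper. Condensed into a hedged, generic helper for illustration (the helper class itself is hypothetical; only JsonMappers.getMapper(), JsonMappers.getLegacyMapper() and the exception type come from the diff):

    import org.datavec.api.transform.serde.JsonMappers;
    import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

    import java.io.IOException;

    public class LegacyAwareJson {
        /** Try the current mapper first, then fall back to the legacy-format mapper. */
        public static <T> T readWithFallback(String json, Class<T> type) {
            try {
                return JsonMappers.getMapper().readValue(json, type);
            } catch (InvalidTypeIdException e) {
                if (e.getMessage().contains("@class")) {
                    // Probably 1.0.0-alpha-or-earlier JSON without the "@class" type property
                    try {
                        return JsonMappers.getLegacyMapper().readValue(json, type);
                    } catch (IOException e2) {
                        throw new RuntimeException(e2);
                    }
                }
                throw new RuntimeException(e);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }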
@@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
 import org.datavec.api.transform.metadata.CategoricalMetaData;
 import org.datavec.api.transform.metadata.ColumnMetaData;
 import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
 import org.datavec.api.transform.serde.JsonSerializer;
 import org.datavec.api.transform.serde.YamlSerializer;
 import org.nd4j.shade.jackson.annotation.JsonSubTypes;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 import org.nd4j.shade.jackson.databind.JsonNode;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
 import org.nd4j.shade.jackson.databind.node.ArrayNode;

 import java.io.IOException;
@@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable {
     public static DataAnalysis fromJson(String json) {
         try{
             return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (Exception e){
             //Legacy format
             ObjectMapper om = new JsonSerializer().getObjectMapper();
@@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode;
 import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
 import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
 import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.serde.JsonMappers;
 import org.datavec.api.transform.serde.JsonSerializer;
 import org.datavec.api.transform.serde.YamlSerializer;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

 import java.io.IOException;
 import java.util.List;
@@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis {
     public static SequenceDataAnalysis fromJson(String json){
         try{
             return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (IOException e){
             throw new RuntimeException(e);
         }
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.analysis.columns;

 import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -27,8 +26,7 @@ import java.io.Serializable;
 * Interface for column analysis
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface ColumnAnalysis extends Serializable {

     long getCountTotal();
@@ -18,7 +18,6 @@ package org.datavec.api.transform.condition;

 import org.datavec.api.transform.ColumnOp;
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -35,8 +34,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ConditionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Condition extends Serializable, ColumnOp {

     /**
@@ -18,7 +18,6 @@ package org.datavec.api.transform.filter;

 import org.datavec.api.transform.ColumnOp;
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -33,8 +32,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.FilterHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Filter extends Serializable, ColumnOp {

     /**
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.metadata;

 import org.datavec.api.transform.ColumnType;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.io.Serializable;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface ColumnMetaData extends Serializable, Cloneable {

     /**
@@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable;

 import java.util.List;

-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
+import static org.nd4j.shade.guava.base.Preconditions.checkArgument;
+import static org.nd4j.shade.guava.base.Preconditions.checkNotNull;

 /**
 * A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition},
@@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData;
 import org.datavec.api.transform.metadata.LongMetaData;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.transform.schema.SequenceSchema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.comparator.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
@@ -50,8 +49,7 @@ import java.util.List;
 @EqualsAndHashCode(exclude = {"inputSchema"})
 @JsonIgnoreProperties({"inputSchema"})
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public class CalculateSortedRank implements Serializable, ColumnOp {

     private final String newColumnName;
@@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
 import org.datavec.api.transform.ColumnType;
 import org.datavec.api.transform.metadata.*;
 import org.datavec.api.transform.serde.JsonMappers;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.*;
 import org.joda.time.DateTimeZone;
 import org.nd4j.shade.jackson.annotation.*;
@@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 import org.nd4j.shade.jackson.databind.SerializationFeature;
+import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
 import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
 import org.nd4j.shade.jackson.datatype.joda.JodaModule;

+import java.io.IOException;
 import java.io.Serializable;
 import java.util.*;

@@ -48,8 +49,7 @@ import java.util.*;
 */
 @JsonIgnoreProperties({"columnNames", "columnNamesIndex"})
 @EqualsAndHashCode
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SchemaHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 @Data
 public class Schema implements Serializable {

@@ -358,6 +358,16 @@ public class Schema implements Serializable {
     public static Schema fromJson(String json) {
         try{
             return JsonMappers.getMapper().readValue(json, Schema.class);
+        } catch (InvalidTypeIdException e){
+            if(e.getMessage().contains("@class")){
+                try{
+                    //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
+                    return JsonMappers.getLegacyMapper().readValue(json, Schema.class);
+                } catch (IOException e2){
+                    throw new RuntimeException(e2);
+                }
+            }
+            throw new RuntimeException(e);
         } catch (Exception e){
             //TODO better exceptions
             throw new RuntimeException(e);
@@ -379,21 +389,6 @@
         }
     }

-    private static Schema fromJacksonString(String str, JsonFactory factory) {
-        ObjectMapper om = new ObjectMapper(factory);
-        om.registerModule(new JodaModule());
-        om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        om.enable(SerializationFeature.INDENT_OUTPUT);
-        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
-        om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
-        try {
-            return om.readValue(str, Schema.class);
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-    }
-
     public static class Builder {
         List<ColumnMetaData> columnMetaData = new ArrayList<>();

@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -30,8 +29,7 @@ import java.util.List;
 * Compare the time steps of a sequence
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface SequenceComparator extends Comparator<List<Writable>>, Serializable {

     /**
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -32,8 +31,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface SequenceSplit extends Serializable {

     /**
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.sequence.window;

 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -36,8 +35,7 @@ import java.util.List;
 * @author Alex Black
 */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface WindowFunction extends Serializable {

     /**
@ -16,44 +16,17 @@
|
||||||
|
|
||||||
package org.datavec.api.transform.serde;
|
package org.datavec.api.transform.serde;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.api.io.WritableComparator;
|
import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;
|
||||||
import org.datavec.api.transform.Transform;
|
|
||||||
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
|
|
||||||
import org.datavec.api.transform.condition.column.ColumnCondition;
|
|
||||||
import org.datavec.api.transform.filter.Filter;
|
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
|
||||||
import org.datavec.api.transform.rank.CalculateSortedRank;
|
|
||||||
import org.datavec.api.transform.schema.Schema;
|
|
||||||
import org.datavec.api.transform.sequence.SequenceComparator;
|
|
||||||
import org.datavec.api.transform.sequence.SequenceSplit;
|
|
||||||
import org.datavec.api.transform.sequence.window.WindowFunction;
|
|
||||||
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
|
||||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.nd4j.serde.json.LegacyIActivationDeserializer;
|
|
||||||
import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
||||||
import org.nd4j.shade.jackson.databind.*;
|
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||||
import org.nd4j.shade.jackson.databind.cfg.MapperConfig;
|
import org.nd4j.shade.jackson.databind.MapperFeature;
|
||||||
import org.nd4j.shade.jackson.databind.introspect.Annotated;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
|
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
import org.nd4j.shade.jackson.databind.introspect.AnnotationMap;
|
|
||||||
import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector;
|
|
||||||
import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder;
|
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
||||||
|
|
||||||
import java.lang.annotation.Annotation;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* JSON mappers for deserializing neural net configurations, etc.
|
* JSON mappers for deserializing neural net configurations, etc.
|
||||||
*
|
*
|
||||||
|
@ -62,38 +35,9 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JsonMappers {
|
public class JsonMappers {
|
||||||
|
|
||||||
/**
|
|
||||||
* This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])}
|
|
||||||
* Classes can be specified in comma-separated format
|
|
||||||
*/
|
|
||||||
public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses";
|
|
||||||
|
|
||||||
static {
|
|
||||||
String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY);
|
|
||||||
if(p != null && !p.isEmpty()){
|
|
||||||
String[] split = p.split(",");
|
|
||||||
List<Class<?>> list = new ArrayList<>();
|
|
||||||
for(String s : split){
|
|
||||||
try{
|
|
||||||
Class<?> c = Class.forName(s);
|
|
||||||
list.add(c);
|
|
||||||
} catch (Throwable t){
|
|
||||||
log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if(list.size() > 0){
|
|
||||||
try {
|
|
||||||
registerLegacyCustomClassesForJSONList(list);
|
|
||||||
} catch (Throwable t){
|
|
||||||
log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static ObjectMapper jsonMapper;
|
private static ObjectMapper jsonMapper;
|
||||||
private static ObjectMapper yamlMapper;
|
private static ObjectMapper yamlMapper;
|
||||||
|
private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc
|
||||||
|
|
||||||
static {
|
static {
|
||||||
jsonMapper = new ObjectMapper();
|
jsonMapper = new ObjectMapper();
|
||||||
|
@@ -102,117 +46,12 @@ public class JsonMappers {
         configureMapper(yamlMapper);
     }
 
-    private static Map<Class, ObjectMapper> legacyMappers = new ConcurrentHashMap<>();
-
-    /**
-     * Register a set of classes (Transform, Filter, etc) for JSON deserialization.<br>
-     * <br>
-     * This is required ONLY when BOTH of the following conditions are met:<br>
-     * 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND<br>
-     * 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)<br>
-     * <br>
-     * By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON
-     * format change between versions.
-     *
-     * @param classes Classes to register
-     */
-    public static void registerLegacyCustomClassesForJSON(Class<?>... classes) {
-        registerLegacyCustomClassesForJSONList(Arrays.<Class<?>>asList(classes));
-    }
-
-    /**
-     * @see #registerLegacyCustomClassesForJSON(Class[])
-     */
-    public static void registerLegacyCustomClassesForJSONList(List<Class<?>> classes){
-        //Default names (i.e., old format for custom JSON format)
-        List<Pair<String,Class>> list = new ArrayList<>();
-        for(Class<?> c : classes){
-            list.add(new Pair<String,Class>(c.getSimpleName(), c));
-        }
-        registerLegacyCustomClassesForJSON(list);
-    }
-
-    /**
-     * Set of classes that can be registered for legacy deserialization.
-     */
-    private static List<Class<?>> REGISTERABLE_CUSTOM_CLASSES = (List<Class<?>>) Arrays.<Class<?>>asList(
-            Transform.class,
-            ColumnAnalysis.class,
-            ColumnCondition.class,
-            Filter.class,
-            ColumnMetaData.class,
-            CalculateSortedRank.class,
-            Schema.class,
-            SequenceComparator.class,
-            SequenceSplit.class,
-            WindowFunction.class,
-            Writable.class,
-            WritableComparator.class
-    );
-
-    /**
-     * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
-     * ONLY) for JSON deserialization, with custom names.<br>
-     * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
-     * but is added in case it is required in non-standard circumstances.
-     */
-    public static void registerLegacyCustomClassesForJSON(List<Pair<String,Class>> classes){
-        for(Pair<String,Class> p : classes){
-            String s = p.getFirst();
-            Class c = p.getRight();
-            //Check if it's a valid class to register...
-            boolean found = false;
-            for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){
-                if(c2.isAssignableFrom(c)){
-                    Map<String,String> map = LegacyMappingHelper.legacyMappingForClass(c2);
-                    map.put(p.getFirst(), p.getSecond().getName());
-                    found = true;
-                }
-            }
-
-            if(!found){
-                throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
-                        c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
-            }
-        }
-    }
-
-    /**
-     * Get the legacy JSON mapper for the specified class.<br>
-     *
-     * <b>NOTE</b>: This is intended for internal backward-compatibility use.
-     *
-     * Note to developers: The following JSON mappers are for handling legacy format JSON.
-     * Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc were changed from
-     * a wrapper object, to an "@class" field. However, to not break all saved transforms networks, these mappers are
-     * part of the solution.<br>
-     * <br>
-     * How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)<br>
-     * 1. Transforms etc JSON that has a "@class" field are deserialized as normal<br>
-     * 2. Transforms JSON that don't have such a field are mapped (via Layer @JsonTypeInfo) to LegacyMappingHelper.TransformHelper<br>
-     * 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it<br>
-     * 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names
-     * 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization
-     *
-     * Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format,
-     * as it'll fail due to not having the expected "@class" annotation.
-     * Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified
-     * class anyway. The ignoring is done via an annotation introspector, defined below in this class.
-     * However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of
-     * all types - if we did, then any nested types would fail (i.e., an Condition in a Transform - the Transform couldn't
-     * be deserialized correctly, as the annotation would be ignored).
-     *
-     */
-    public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class<?> clazz){
-        if(!legacyMappers.containsKey(clazz)){
-            ObjectMapper m = new ObjectMapper();
-            configureMapper(m);
-            m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.<Class>singletonList(clazz)));
-            legacyMappers.put(clazz, m);
-        }
-        return legacyMappers.get(clazz);
-    }
+    public static synchronized ObjectMapper getLegacyMapper(){
+        if(legacyMapper == null){
+            legacyMapper = LegacyJsonFormat.legacyMapper();
+            configureMapper(legacyMapper);
+        }
+        return legacyMapper;
+    }
 
     /**
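The removed javadoc above spells out the two conditions under which registration was needed. A minimal sketch of how that (now removed) hook was typically called before loading an old TransformProcess; MyCustomFilter is hypothetical and stands in for any custom Transform/Filter/Condition serialized with DataVec 1.0.0-alpha or earlier:

// Sketch only: register the custom class under its simple name, then load the old JSON
// through the normal public entry point.
public static TransformProcess loadOldProcess(String oldJson) {
    JsonMappers.registerLegacyCustomClassesForJSON(MyCustomFilter.class);
    return TransformProcess.fromJson(oldJson);
}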
@@ -237,61 +76,7 @@ public class JsonMappers {
         ret.enable(SerializationFeature.INDENT_OUTPUT);
         ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
         ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
+        ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);    //Need this otherwise JsonProperty annotations on constructors won't be seen
     }
 
-
-    /**
-     * Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc.
-     * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring
-     * a set of JsonTypeInfo annotations
-     */
-    @AllArgsConstructor
-    private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector {
-
-        private List<Class> classList;
-
-        @Override
-        protected TypeResolverBuilder<?> _findTypeResolver(MapperConfig<?> config, Annotated ann, JavaType baseType) {
-            if(ann instanceof AnnotatedClass){
-                AnnotatedClass c = (AnnotatedClass)ann;
-                Class<?> annClass = c.getAnnotated();
-
-                boolean isAssignable = false;
-                for(Class c2 : classList){
-                    if(c2.isAssignableFrom(annClass)){
-                        isAssignable = true;
-                        break;
-                    }
-                }
-
-                if( isAssignable ){
-                    AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations();
-                    if(annotations == null || annotations.annotations() == null){
-                        //Probably not necessary - but here for safety
-                        return super._findTypeResolver(config, ann, baseType);
-                    }
-
-                    AnnotationMap newMap = null;
-                    for(Annotation a : annotations.annotations()){
-                        Class<?> annType = a.annotationType();
-                        if(annType == JsonTypeInfo.class){
-                            //Ignore the JsonTypeInfo annotation on the Layer class
-                            continue;
-                        }
-                        if(newMap == null){
-                            newMap = new AnnotationMap();
-                        }
-                        newMap.add(a);
-                    }
-                    if(newMap == null)
-                        return null;
-
-                    //Pass the remaining annotations (if any) to the original introspector
-                    AnnotatedClass ann2 = c.withAnnotations(newMap);
-                    return super._findTypeResolver(config, ann2, baseType);
-                }
-            }
-            return super._findTypeResolver(config, ann, baseType);
-        }
-    }
 }
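The single added line in that hunk raises creator visibility so that constructor-based deserialization still works under the otherwise locked-down visibility settings; the diff's own comment is the authority for why. A hypothetical class of the shape being protected looks like this (not from the diff):

import org.nd4j.shade.jackson.annotation.JsonProperty;

public class ExampleRange {
    private final double min;
    private final double max;

    // Deserialized through the constructor: the parameter-level @JsonProperty annotations
    // are what the CREATOR visibility setting keeps usable.
    public ExampleRange(@JsonProperty("min") double min, @JsonProperty("max") double max) {
        this.min = min;
        this.max = max;
    }

    public double getMin() { return min; }
    public double getMax() { return max; }
}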
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.api.transform.serde.legacy;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.Getter;
-import org.datavec.api.transform.serde.JsonMappers;
-import org.nd4j.serde.json.BaseLegacyDeserializer;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
-
-import java.util.Map;
-
-@AllArgsConstructor
-@Data
-public class GenericLegacyDeserializer<T> extends BaseLegacyDeserializer<T> {
-
-    @Getter
-    protected final Class<T> deserializedType;
-    @Getter
-    protected final Map<String,String> legacyNamesMap;
-
-    @Override
-    public ObjectMapper getLegacyJsonMapper() {
-        return JsonMappers.getLegacyMapperFor(getDeserializedType());
-    }
-}
@@ -0,0 +1,267 @@
+package org.datavec.api.transform.serde.legacy;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.datavec.api.transform.Transform;
+import org.datavec.api.transform.analysis.columns.*;
+import org.datavec.api.transform.condition.BooleanCondition;
+import org.datavec.api.transform.condition.Condition;
+import org.datavec.api.transform.condition.column.*;
+import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
+import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
+import org.datavec.api.transform.filter.ConditionFilter;
+import org.datavec.api.transform.filter.Filter;
+import org.datavec.api.transform.filter.FilterInvalidValues;
+import org.datavec.api.transform.filter.InvalidNumColumns;
+import org.datavec.api.transform.metadata.*;
+import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
+import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
+import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
+import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
+import org.datavec.api.transform.rank.CalculateSortedRank;
+import org.datavec.api.transform.schema.Schema;
+import org.datavec.api.transform.schema.SequenceSchema;
+import org.datavec.api.transform.sequence.ReduceSequenceTransform;
+import org.datavec.api.transform.sequence.SequenceComparator;
+import org.datavec.api.transform.sequence.SequenceSplit;
+import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
+import org.datavec.api.transform.sequence.comparator.StringComparator;
+import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
+import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
+import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
+import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
+import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
+import org.datavec.api.transform.sequence.window.TimeWindowFunction;
+import org.datavec.api.transform.sequence.window.WindowFunction;
+import org.datavec.api.transform.stringreduce.IStringReducer;
+import org.datavec.api.transform.stringreduce.StringReducer;
+import org.datavec.api.transform.transform.categorical.*;
+import org.datavec.api.transform.transform.column.*;
+import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
+import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
+import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
+import org.datavec.api.transform.transform.doubletransform.*;
+import org.datavec.api.transform.transform.integer.*;
+import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
+import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
+import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
+import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
+import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
+import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
+import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
+import org.datavec.api.transform.transform.string.*;
+import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
+import org.datavec.api.transform.transform.time.StringToTimeTransform;
+import org.datavec.api.transform.transform.time.TimeMathOpTransform;
+import org.datavec.api.writable.*;
+import org.datavec.api.writable.comparator.*;
+import org.nd4j.shade.jackson.annotation.JsonInclude;
+import org.nd4j.shade.jackson.annotation.JsonSubTypes;
+import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
+import org.nd4j.shade.jackson.databind.ObjectMapper;
+
+/**
+ * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override
+ * the existing annotations.
+ * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field".
+ * We use these mixins to allow us to still load the old format
+ *
+ * @author Alex Black
+ */
+public class LegacyJsonFormat {
+
+    private LegacyJsonFormat(){ }
+
+    /**
+     * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before
+     * @return Object mapper
+     */
+    public static ObjectMapper legacyMapper(){
+        ObjectMapper om = new ObjectMapper();
+        om.addMixIn(Schema.class, SchemaMixin.class);
+        om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class);
+        om.addMixIn(Transform.class, TransformMixin.class);
+        om.addMixIn(Condition.class, ConditionMixin.class);
+        om.addMixIn(Writable.class, WritableMixin.class);
+        om.addMixIn(Filter.class, FilterMixin.class);
+        om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class);
+        om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class);
+        om.addMixIn(WindowFunction.class, WindowFunctionMixin.class);
+        om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class);
+        om.addMixIn(WritableComparator.class, WritableComparatorMixin.class);
+        om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class);
+        om.addMixIn(IStringReducer.class, IStringReducerMixin.class);
+        return om;
+    }
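For orientation, the two subtype styles that legacyMapper() bridges look roughly like this (payload fields elided): the pre-1.0.0-beta "wrapper object" form, e.g. {"StringToTimeTransform" : { ... }}, versus the current "@class" form, e.g. {"@class" : "org.datavec.api.transform.transform.time.StringToTimeTransform", ...}. A minimal usage sketch, assuming a hypothetical file written by 1.0.0-alpha or earlier and that TransformProcess.fromJson falls back to this legacy path as this PR intends:

import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import org.datavec.api.transform.TransformProcess;

public class LoadLegacyProcess {
    public static void main(String[] args) throws Exception {
        // Hypothetical path; any TransformProcess JSON saved by 1.0.0-alpha or earlier.
        String oldJson = new String(Files.readAllBytes(Paths.get("oldTransformProcess.json")),
                StandardCharsets.UTF_8);
        // The public entry point is unchanged; the mixin-based legacy mapper is internal.
        TransformProcess tp = TransformProcess.fromJson(oldJson);
        System.out.println(tp.getFinalSchema());
    }
}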
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"),
+            @JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class SchemaMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"),
+            @JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"),
+            @JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"),
+            @JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"),
+            @JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"),
+            @JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"),
+            @JsonSubTypes.Type(value = LongMetaData.class, name = "Long"),
+            @JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"),
+            @JsonSubTypes.Type(value = StringMetaData.class, name = "String"),
+            @JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class ColumnMetaDataMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"),
+            @JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"),
+            @JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"),
+            @JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"),
+            @JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"),
+            @JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"),
+            @JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"),
+            @JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"),
+            @JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"),
+            @JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"),
+            @JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"),
+            @JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"),
+            @JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"),
+            @JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"),
+            @JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"),
+            @JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"),
+            @JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"),
+            @JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"),
+            @JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"),
+            @JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"),
+            @JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"),
+            @JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"),
+            @JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"),
+            @JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"),
+            @JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"),
+            @JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"),
+            @JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"),
+            @JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"),
+            @JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"),
+            @JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"),
+            @JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"),
+            @JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"),
+            @JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"),
+            @JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"),
+            @JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"),
+            @JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"),
+            @JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"),
+            @JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"),
+            @JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"),
+            @JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"),
+            @JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"),
+            @JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"),
+            @JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"),
+            @JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"),
+            @JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"),
+            @JsonSubTypes.Type(value = SequenceOffsetTransform.class, name = "SequenceOffsetTransform"),
+            @JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"),
+            @JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"),
+            @JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"),
+            @JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"),
+            @JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"),
+            @JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"),
+            @JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"),
+            @JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"),
+            @JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"),
+            @JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class TransformMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"),
+            @JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"),
+            @JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"),
+            @JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"),
+            @JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"),
+            @JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"),
+            @JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"),
+            @JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"),
+            @JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"),
+            @JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"),
+            @JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"),
+            @JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"),
+            @JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class ConditionMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"),
+            @JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"),
+            @JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"),
+            @JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"),
+            @JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"),
+            @JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"),
+            @JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"),
+            @JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"),
+            @JsonSubTypes.Type(value = Text.class, name = "Text"),
+            @JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class WritableMixin { }
+
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"),
+            @JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"),
+            @JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class FilterMixin { }
+
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"),
+            @JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class SequenceComparatorMixin { }
+
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"),
+            @JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class SequenceSplitMixin { }
+
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"),
+            @JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class WindowFunctionMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class CalculateSortedRankMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"),
+            @JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"),
+            @JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"),
+            @JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"),
+            @JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class WritableComparatorMixin { }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"),
+            @JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"),
+            @JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"),
+            @JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"),
+            @JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"),
+            @JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"),
+            @JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class ColumnAnalysisMixin{ }
+
+    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
+    @JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")})
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static class IStringReducerMixin{ }
+}
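The mixin classes above attach the legacy wrapper-object subtype information without touching the real DataVec classes. A self-contained sketch of the same mechanism with hypothetical types (using the shaded Jackson packages that DataVec ships; nothing here is from the diff):

import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class MixinDemo {
    // Hypothetical type hierarchy, standing in for Transform/Condition/etc.
    public interface Shape { }
    public static class Circle implements Shape {
        public double radius;
    }

    // The mixin carries the legacy "wrapper object" subtype annotations so the
    // real interface does not have to.
    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
    @JsonSubTypes({@JsonSubTypes.Type(value = Circle.class, name = "Circle")})
    public static class ShapeMixin { }

    public static void main(String[] args) throws Exception {
        ObjectMapper om = new ObjectMapper();
        om.addMixIn(Shape.class, ShapeMixin.class);
        // Old-style JSON: the subtype name wraps the payload object.
        Shape s = om.readValue("{\"Circle\":{\"radius\":2.0}}", Shape.class);
        System.out.println(s.getClass().getSimpleName());   // prints Circle
    }
}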
@@ -1,535 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.api.transform.serde.legacy;
-
-import org.datavec.api.transform.Transform;
-import org.datavec.api.transform.analysis.columns.*;
-import org.datavec.api.transform.condition.BooleanCondition;
-import org.datavec.api.transform.condition.Condition;
-import org.datavec.api.transform.condition.column.*;
-import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
-import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
-import org.datavec.api.transform.filter.ConditionFilter;
-import org.datavec.api.transform.filter.Filter;
-import org.datavec.api.transform.filter.FilterInvalidValues;
-import org.datavec.api.transform.filter.InvalidNumColumns;
-import org.datavec.api.transform.metadata.*;
-import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
-import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
-import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
-import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
-import org.datavec.api.transform.rank.CalculateSortedRank;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.schema.SequenceSchema;
-import org.datavec.api.transform.sequence.ReduceSequenceTransform;
-import org.datavec.api.transform.sequence.SequenceComparator;
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
-import org.datavec.api.transform.sequence.comparator.StringComparator;
-import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
-import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
-import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
-import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
-import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
-import org.datavec.api.transform.sequence.window.TimeWindowFunction;
-import org.datavec.api.transform.sequence.window.WindowFunction;
-import org.datavec.api.transform.stringreduce.IStringReducer;
-import org.datavec.api.transform.stringreduce.StringReducer;
-import org.datavec.api.transform.transform.categorical.*;
-import org.datavec.api.transform.transform.column.*;
-import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
-import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
-import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
-import org.datavec.api.transform.transform.doubletransform.*;
-import org.datavec.api.transform.transform.integer.*;
-import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
-import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
-import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
-import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform;
-import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
-import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
-import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
-import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
-import org.datavec.api.transform.transform.string.*;
-import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
-import org.datavec.api.transform.transform.time.StringToTimeTransform;
-import org.datavec.api.transform.transform.time.TimeMathOpTransform;
-import org.datavec.api.writable.*;
-import org.datavec.api.writable.comparator.*;
-import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
-
-import java.util.HashMap;
-import java.util.Map;
-
-public class LegacyMappingHelper {
-
-    public static Map<String,String> legacyMappingForClass(Class c){
-        //Need to be able to get the map - and they need to be mutable...
-        switch (c.getSimpleName()){
-            case "Transform":
-                return getLegacyMappingImageTransform();
-            case "ColumnAnalysis":
-                return getLegacyMappingColumnAnalysis();
-            case "Condition":
-                return getLegacyMappingCondition();
-            case "Filter":
-                return getLegacyMappingFilter();
-            case "ColumnMetaData":
-                return mapColumnMetaData;
-            case "CalculateSortedRank":
-                return mapCalculateSortedRank;
-            case "Schema":
-                return mapSchema;
-            case "SequenceComparator":
-                return mapSequenceComparator;
-            case "SequenceSplit":
-                return mapSequenceSplit;
-            case "WindowFunction":
-                return mapWindowFunction;
-            case "IStringReducer":
-                return mapIStringReducer;
-            case "Writable":
-                return mapWritable;
-            case "WritableComparator":
-                return mapWritableComparator;
-            case "ImageTransform":
-                return mapImageTransform;
-            default:
-                //Should never happen
-                throw new IllegalArgumentException("No legacy mapping available for class " + c.getName());
-        }
-    }
-
-    private static Map<String,String> mapTransform;
-    private static Map<String,String> mapColumnAnalysis;
-    private static Map<String,String> mapCondition;
-    private static Map<String,String> mapFilter;
-    private static Map<String,String> mapColumnMetaData;
-    private static Map<String,String> mapCalculateSortedRank;
-    private static Map<String,String> mapSchema;
-    private static Map<String,String> mapSequenceComparator;
-    private static Map<String,String> mapSequenceSplit;
-    private static Map<String,String> mapWindowFunction;
-    private static Map<String,String> mapIStringReducer;
-    private static Map<String,String> mapWritable;
-    private static Map<String,String> mapWritableComparator;
-    private static Map<String,String> mapImageTransform;
-
-    private static synchronized Map<String,String> getLegacyMappingTransform(){
-
-        if(mapTransform == null) {
-            //The following classes all used their class short name
-            Map<String, String> m = new HashMap<>();
-            m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName());
-            m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName());
-            m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName());
-            m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName());
-            m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName());
-            m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName());
-            m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName());
-            m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName());
-            m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName());
-            m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName());
-            m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName());
-            m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName());
-            m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName());
-            m.put("Log2Normalizer", Log2Normalizer.class.getName());
-            m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName());
-            m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName());
-            m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName());
-            m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName());
-            m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName());
-            m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName());
-            m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName());
-            m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName());
-            m.put("LongMathOpTransform", LongMathOpTransform.class.getName());
-            m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName());
-            m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName());
-            m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName());
-            m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName());
-            m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName());
-            m.put("StringMapTransform", StringMapTransform.class.getName());
-            m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName());
-            m.put("StringToTimeTransform", StringToTimeTransform.class.getName());
-            m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName());
-            m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName());
-            m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName());
-            m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName());
-            m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName());
-            m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName());
-            m.put("ConvertToStringTransform", ConvertToString.class.getName());
-            m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName());
-            m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName());
-            m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName());
-            m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName());
-            m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName());
-            m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName());
-            m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName());
-            m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName());
-            m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName());
-            m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName());
-            m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName());
-            m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName());
-            m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName());
-            m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName());
-            m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName());
-            m.put("PivotTransform", PivotTransform.class.getName());
-            m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName());
-            m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName());
-            m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName());
-
-            mapTransform = m;
-        }
-
-        return mapTransform;
-    }
-
-    private static Map<String,String> getLegacyMappingColumnAnalysis(){
-        if(mapColumnAnalysis == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("BytesAnalysis", BytesAnalysis.class.getName());
-            m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName());
-            m.put("DoubleAnalysis", DoubleAnalysis.class.getName());
-            m.put("IntegerAnalysis", IntegerAnalysis.class.getName());
-            m.put("LongAnalysis", LongAnalysis.class.getName());
-            m.put("StringAnalysis", StringAnalysis.class.getName());
-            m.put("TimeAnalysis", TimeAnalysis.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName());
-
-            mapColumnAnalysis = m;
-        }
-
-        return mapColumnAnalysis;
-    }
-
-    private static Map<String,String> getLegacyMappingCondition(){
-        if(mapCondition == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName());
-            m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName());
-            m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName());
-            m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName());
-            m.put("LongColumnCondition", LongColumnCondition.class.getName());
-            m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName());
-            m.put("StringColumnCondition", StringColumnCondition.class.getName());
-            m.put("TimeColumnCondition", TimeColumnCondition.class.getName());
-            m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName());
-            m.put("BooleanCondition", BooleanCondition.class.getName());
-            m.put("NaNColumnCondition", NaNColumnCondition.class.getName());
-            m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName());
-            m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName());
-            m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName());
-
-            mapCondition = m;
-        }
-
-        return mapCondition;
-    }
-
-    private static Map<String,String> getLegacyMappingFilter(){
-        if(mapFilter == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("ConditionFilter", ConditionFilter.class.getName());
-            m.put("FilterInvalidValues", FilterInvalidValues.class.getName());
-            m.put("InvalidNumCols", InvalidNumColumns.class.getName());
-
-            mapFilter = m;
-        }
-        return mapFilter;
-    }
-
-    private static Map<String,String> getLegacyMappingColumnMetaData(){
-        if(mapColumnMetaData == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("Categorical", CategoricalMetaData.class.getName());
-            m.put("Double", DoubleMetaData.class.getName());
-            m.put("Float", FloatMetaData.class.getName());
-            m.put("Integer", IntegerMetaData.class.getName());
-            m.put("Long", LongMetaData.class.getName());
-            m.put("String", StringMetaData.class.getName());
-            m.put("Time", TimeMetaData.class.getName());
-            m.put("NDArray", NDArrayMetaData.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName());
-            m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName());
-
-            mapColumnMetaData = m;
-        }
-
-        return mapColumnMetaData;
-    }
-
-    private static Map<String,String> getLegacyMappingCalculateSortedRank(){
-        if(mapCalculateSortedRank == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("CalculateSortedRank", CalculateSortedRank.class.getName());
-            mapCalculateSortedRank = m;
-        }
-        return mapCalculateSortedRank;
-    }
-
-    private static Map<String,String> getLegacyMappingSchema(){
-        if(mapSchema == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("Schema", Schema.class.getName());
-            m.put("SequenceSchema", SequenceSchema.class.getName());
-
-            mapSchema = m;
-        }
-        return mapSchema;
-    }
-
-    private static Map<String,String> getLegacyMappingSequenceComparator(){
-        if(mapSequenceComparator == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName());
-            m.put("StringComparator", StringComparator.class.getName());
-
-            mapSequenceComparator = m;
-        }
-        return mapSequenceComparator;
-    }
-
-    private static Map<String,String> getLegacyMappingSequenceSplit(){
-        if(mapSequenceSplit == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName());
-            m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName());
-
-            mapSequenceSplit = m;
-        }
-        return mapSequenceSplit;
-    }
-
-    private static Map<String,String> getLegacyMappingWindowFunction(){
-        if(mapWindowFunction == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("TimeWindowFunction", TimeWindowFunction.class.getName());
-            m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName());
-
-            mapWindowFunction = m;
-        }
-        return mapWindowFunction;
-    }
-
-    private static Map<String,String> getLegacyMappingIStringReducer(){
-        if(mapIStringReducer == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("StringReducer", StringReducer.class.getName());
-
-            mapIStringReducer = m;
-        }
-        return mapIStringReducer;
-    }
-
-    private static Map<String,String> getLegacyMappingWritable(){
-        if (mapWritable == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("ArrayWritable", ArrayWritable.class.getName());
-            m.put("BooleanWritable", BooleanWritable.class.getName());
-            m.put("ByteWritable", ByteWritable.class.getName());
-            m.put("DoubleWritable", DoubleWritable.class.getName());
-            m.put("FloatWritable", FloatWritable.class.getName());
-            m.put("IntWritable", IntWritable.class.getName());
-            m.put("LongWritable", LongWritable.class.getName());
-            m.put("NullWritable", NullWritable.class.getName());
-            m.put("Text", Text.class.getName());
-            m.put("BytesWritable", BytesWritable.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName());
-
-            mapWritable = m;
-        }
-
-        return mapWritable;
-    }
-
-    private static Map<String,String> getLegacyMappingWritableComparator(){
-        if(mapWritableComparator == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName());
-            m.put("FloatWritableComparator", FloatWritableComparator.class.getName());
-            m.put("IntWritableComparator", IntWritableComparator.class.getName());
-            m.put("LongWritableComparator", LongWritableComparator.class.getName());
-            m.put("TextWritableComparator", TextWritableComparator.class.getName());
-
-            //The following never had subtype annotations, and hence will have had the default name:
-            m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName());
-            m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName());
-            m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName());
-            m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName());
-            m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName());
-            m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName());
-            m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName());
-            m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName());
-
-            mapWritableComparator = m;
-        }
-
-        return mapWritableComparator;
-    }
-
-    public static Map<String,String> getLegacyMappingImageTransform(){
-        if(mapImageTransform == null) {
-            Map<String, String> m = new HashMap<>();
-            m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform");
-            m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform");
-            m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform");
-            m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform");
-            m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform");
-            m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform");
-            m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform");
-            m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform");
-            m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform");
-            m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform");
-            m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform");
-            m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform");
-
-            mapImageTransform = m;
-        }
-        return mapImageTransform;
-    }
-
-    @JsonDeserialize(using = LegacyTransformDeserializer.class)
-    public static class TransformHelper { }
-
-    public static class LegacyTransformDeserializer extends GenericLegacyDeserializer<Transform> {
-        public LegacyTransformDeserializer() {
-            super(Transform.class, getLegacyMappingTransform());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class)
-    public static class ColumnAnalysisHelper { }
-
-    public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer<ColumnAnalysis> {
-        public LegacyColumnAnalysisDeserializer() {
-            super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyConditionDeserializer.class)
-    public static class ConditionHelper { }
-
-    public static class LegacyConditionDeserializer extends GenericLegacyDeserializer<Condition> {
-        public LegacyConditionDeserializer() {
-            super(Condition.class, getLegacyMappingCondition());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyFilterDeserializer.class)
-    public static class FilterHelper { }
-
-    public static class LegacyFilterDeserializer extends GenericLegacyDeserializer<Filter> {
-        public LegacyFilterDeserializer() {
-            super(Filter.class, getLegacyMappingFilter());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class)
-    public static class ColumnMetaDataHelper { }
-
-    public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer<ColumnMetaData> {
-        public LegacyColumnMetaDataDeserializer() {
-            super(ColumnMetaData.class, getLegacyMappingColumnMetaData());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class)
-    public static class CalculateSortedRankHelper { }
-
-    public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer<CalculateSortedRank> {
-        public LegacyCalculateSortedRankDeserializer() {
-            super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank());
-        }
-    }
-
-    @JsonDeserialize(using = LegacySchemaDeserializer.class)
-    public static class SchemaHelper { }
-
-    public static class LegacySchemaDeserializer extends GenericLegacyDeserializer<Schema> {
-        public LegacySchemaDeserializer() {
-            super(Schema.class, getLegacyMappingSchema());
-        }
-    }
-
-    @JsonDeserialize(using = LegacySequenceComparatorDeserializer.class)
-    public static class SequenceComparatorHelper { }
-
-    public static class LegacySequenceComparatorDeserializer extends GenericLegacyDeserializer<SequenceComparator> {
-        public LegacySequenceComparatorDeserializer() {
-            super(SequenceComparator.class, getLegacyMappingSequenceComparator());
-        }
-    }
-
-    @JsonDeserialize(using = LegacySequenceSplitDeserializer.class)
-    public static class SequenceSplitHelper { }
-
-    public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer<SequenceSplit> {
-        public LegacySequenceSplitDeserializer() {
-            super(SequenceSplit.class, getLegacyMappingSequenceSplit());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyWindowFunctionDeserializer.class)
-    public static class WindowFunctionHelper { }
-
-    public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer<WindowFunction> {
-        public LegacyWindowFunctionDeserializer() {
-            super(WindowFunction.class, getLegacyMappingWindowFunction());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyIStringReducerDeserializer.class)
-    public static class IStringReducerHelper { }
-
-    public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer<IStringReducer> {
-        public LegacyIStringReducerDeserializer() {
-            super(IStringReducer.class, getLegacyMappingIStringReducer());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyWritableDeserializer.class)
-    public static class WritableHelper { }
-
-    public static class LegacyWritableDeserializer extends GenericLegacyDeserializer<Writable> {
-        public LegacyWritableDeserializer() {
-            super(Writable.class, getLegacyMappingWritable());
-        }
-    }
-
-    @JsonDeserialize(using = LegacyWritableComparatorDeserializer.class)
-    public static class WritableComparatorHelper { }
-
-    public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer<WritableComparator> {
-        public LegacyWritableComparatorDeserializer() {
-            super(WritableComparator.class, getLegacyMappingWritableComparator());
-        }
}
|
|
||||||
}
|
|
|
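For reference, a minimal sketch of how a simple-name-to-class map like the ones above can drive deserialization of the old wrapper-object JSON format. This is not the project's actual GenericLegacyDeserializer; the class name LegacyByNameDeserializer, the fresh ObjectMapper, and the assumed wrapper layout are illustrative assumptions only.

import java.io.IOException;
import java.util.Map;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class LegacyByNameDeserializer<T> extends JsonDeserializer<T> {
    private final Class<T> baseType;
    private final Map<String, String> legacyNameToClass;

    public LegacyByNameDeserializer(Class<T> baseType, Map<String, String> legacyNameToClass) {
        this.baseType = baseType;
        this.legacyNameToClass = legacyNameToClass;
    }

    @Override
    public T deserialize(JsonParser jp, DeserializationContext ctx) throws IOException {
        // Assumed legacy shape: {"SomeSimpleName": { ...fields... }}
        JsonNode node = jp.getCodec().readTree(jp);
        if (!node.fieldNames().hasNext()) {
            throw new IOException("Empty legacy JSON object");
        }
        String legacyName = node.fieldNames().next();
        String className = legacyNameToClass.get(legacyName);
        if (className == null) {
            throw new IOException("Unknown legacy type name: " + legacyName);
        }
        try {
            Class<?> target = Class.forName(className);
            // In real code the configured mapper would be reused rather than created here
            ObjectMapper mapper = new ObjectMapper();
            Object value = mapper.treeToValue(node.get(legacyName), target);
            return baseType.cast(value);
        } catch (ClassNotFoundException e) {
            throw new IOException(e);
        }
    }
}

The sketch only shows the lookup-and-delegate idea; base-type handling and mapper configuration live elsewhere in the actual code.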
@@ -17,7 +17,6 @@
 package org.datavec.api.transform.stringreduce;
 
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@@ -31,8 +30,7 @@ import java.util.List;
  * a single List<Writable>
  */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.IStringReducerHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface IStringReducer extends Serializable {
 
     /**
@@ -16,7 +16,7 @@
 
 package org.datavec.api.util.ndarray;
 
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
 import lombok.NonNull;
 import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
@@ -17,7 +17,7 @@
 package org.datavec.api.writable;
 
 
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
 import org.datavec.api.io.WritableComparable;
 import org.datavec.api.io.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -17,7 +17,7 @@
 package org.datavec.api.writable;
 
 
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
 import org.datavec.api.io.WritableComparable;
 import org.datavec.api.io.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -17,7 +17,7 @@
 package org.datavec.api.writable;
 
 
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
 import org.datavec.api.io.WritableComparable;
 import org.datavec.api.io.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -17,7 +17,7 @@
 package org.datavec.api.writable;
 
 
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
 import org.datavec.api.io.WritableComparable;
 import org.datavec.api.io.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -17,7 +17,7 @@
 package org.datavec.api.writable;
 
 
-import com.google.common.math.DoubleMath;
+import org.nd4j.shade.guava.math.DoubleMath;
 import org.datavec.api.io.WritableComparable;
 import org.datavec.api.io.WritableComparator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;
@@ -16,7 +16,6 @@
 
 package org.datavec.api.writable;
 
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 
 import java.io.DataInput;
@@ -60,8 +59,7 @@ import java.io.Serializable;
  * }
  * </pre></blockquote></p>
  */
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.WritableHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface Writable extends Serializable {
     /**
      * Serialize the fields of this object to <code>out</code>.
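Illustration (assumed, not part of the diff) of the effect of the @JsonTypeInfo change above: the concrete type is now written as a fully qualified class name under the "@class" property, so no legacy defaultImpl helper is needed. DoubleWritable is used only as a convenient concrete Writable; the exact JSON field layout may differ.

import org.datavec.api.writable.DoubleWritable;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class WritableTypeInfoExample {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        // The serialized form carries the concrete class, e.g.
        // {"@class":"org.datavec.api.writable.DoubleWritable", ...}
        String json = mapper.writeValueAsString(new DoubleWritable(1.5));
        System.out.println(json);
        // Reading back via mapper.readValue(json, Writable.class) resolves the
        // concrete type from the "@class" property.
    }
}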
@@ -16,7 +16,7 @@
 
 package org.datavec.api.writable.batch;
 
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.Data;
 import lombok.NonNull;
 import org.datavec.api.writable.NDArrayWritable;
@@ -16,16 +16,13 @@
 
 package org.datavec.api.writable.comparator;
 
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
 import org.datavec.api.writable.Writable;
-import org.nd4j.shade.jackson.annotation.JsonSubTypes;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 
 import java.io.Serializable;
 import java.util.Comparator;
 
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface WritableComparator extends Comparator<Writable>, Serializable {
 
 }
@@ -16,7 +16,7 @@
 
 package org.datavec.api.split;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.datavec.api.io.filters.BalancedPathFilter;
 import org.datavec.api.io.filters.RandomPathFilter;
 import org.datavec.api.io.labels.ParentPathLabelGenerator;
@@ -16,7 +16,7 @@
 
 package org.datavec.api.split.parittion;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.datavec.api.conf.Configuration;
 import org.datavec.api.split.FileSplit;
 import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
@@ -78,8 +78,9 @@ public class TestJsonYaml {
     public void testMissingPrimitives() {
 
         Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build();
+        //Legacy format JSON
-        String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n"
+        String strJson = "{\n" + " \"Schema\" : {\n"
+                + " \"columns\" : [ {\n" + " \"Double\" : {\n"
                 + " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" +
                 //" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test
                 //" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test
@@ -16,7 +16,7 @@
 
 package org.datavec.api.writable;
 
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.util.ndarray.RecordConverter;
 import org.junit.Test;
@@ -34,36 +34,6 @@
         <artifactId>nd4j-arrow</artifactId>
         <version>${project.version}</version>
     </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.core</groupId>
-        <artifactId>jackson-core</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.core</groupId>
-        <artifactId>jackson-databind</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.core</groupId>
-        <artifactId>jackson-annotations</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.dataformat</groupId>
-        <artifactId>jackson-dataformat-yaml</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.dataformat</groupId>
-        <artifactId>jackson-dataformat-xml</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
-    <dependency>
-        <groupId>com.fasterxml.jackson.datatype</groupId>
-        <artifactId>jackson-datatype-joda</artifactId>
-        <version>${spark2.jackson.version}</version>
-    </dependency>
     <dependency>
         <groupId>org.datavec</groupId>
         <artifactId>datavec-api</artifactId>
@@ -74,11 +44,6 @@
         <artifactId>hppc</artifactId>
         <version>${hppc.version}</version>
     </dependency>
-    <dependency>
-        <groupId>com.google.guava</groupId>
-        <artifactId>guava</artifactId>
-        <version>${guava.version}</version>
-    </dependency>
     <dependency>
         <groupId>org.apache.arrow</groupId>
         <artifactId>arrow-vector</artifactId>
@@ -16,7 +16,7 @@
 
 package org.datavec.image.recordreader;
 
-import com.google.common.base.Preconditions;
+import org.nd4j.shade.guava.base.Preconditions;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.slf4j.Slf4j;
@@ -1,35 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.image.serde;
-
-import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer;
-import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
-import org.datavec.image.transform.ImageTransform;
-import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
-
-public class LegacyImageMappingHelper {
-
-    @JsonDeserialize(using = LegacyImageTransformDeserializer.class)
-    public static class ImageTransformHelper { }
-
-    public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer<ImageTransform> {
-        public LegacyImageTransformDeserializer() {
-            super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform());
-        }
-    }
-
-}
@@ -16,11 +16,8 @@
 
 package org.datavec.image.transform;
 
-import lombok.Data;
 import org.datavec.image.data.ImageWritable;
-import org.datavec.image.serde.LegacyImageMappingHelper;
 import org.nd4j.shade.jackson.annotation.JsonInclude;
-import org.nd4j.shade.jackson.annotation.JsonSubTypes;
 import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 
 import java.util.Random;
@@ -31,8 +28,7 @@ import java.util.Random;
  * @author saudet
  */
 @JsonInclude(JsonInclude.Include.NON_NULL)
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
-                defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 public interface ImageTransform {
 
     /**
@@ -50,11 +50,6 @@
         <artifactId>netty</artifactId>
         <version>${netty.version}</version>
     </dependency>
-    <dependency>
-        <groupId>com.google.guava</groupId>
-        <artifactId>guava</artifactId>
-        <version>${guava.version}</version>
-    </dependency>
     <dependency>
         <groupId>org.apache.commons</groupId>
         <artifactId>commons-compress</artifactId>
@@ -16,7 +16,7 @@
 
 package org.datavec.hadoop.records.reader;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.*;
@@ -16,7 +16,7 @@
 
 package org.datavec.hadoop.records.reader;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.*;
@@ -16,7 +16,7 @@
 
 package org.datavec.hadoop.records.reader;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.*;
@@ -16,7 +16,7 @@
 
 package org.datavec.hadoop.records.writer;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.datavec.api.records.converter.RecordReaderConverter;
 import org.datavec.api.records.reader.RecordReader;
 import org.datavec.api.records.reader.SequenceRecordReader;
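The Guava changes above only swap the package prefix to ND4J's shaded copy; the API itself is unchanged. A trivial check, assuming the shaded classes (bundled with the nd4j artifacts) are on the classpath:

import java.io.File;
import org.nd4j.shade.guava.io.Files;

public class ShadedGuavaCheck {
    public static void main(String[] args) {
        // Same Guava API as com.google.common.io.Files, only the package prefix differs
        File tmp = Files.createTempDir();
        System.out.println(tmp.getAbsolutePath());
    }
}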
@@ -16,7 +16,7 @@
 
 package org.datavec.local.transforms.join;
 
-import com.google.common.collect.Iterables;
+import org.nd4j.shade.guava.collect.Iterables;
 import org.datavec.api.transform.join.Join;
 import org.datavec.api.writable.Writable;
 import org.datavec.local.transforms.functions.FlatMapFunctionAdapter;
@@ -64,12 +64,6 @@
             <version>${datavec.version}</version>
         </dependency>
 
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-cluster_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
-
         <dependency>
             <groupId>joda-time</groupId>
             <artifactId>joda-time</artifactId>
@@ -106,40 +100,10 @@
             <version>${snakeyaml.version}</version>
         </dependency>
 
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-core</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-databind</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-annotations</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-
-        <dependency>
-            <groupId>com.fasterxml.jackson.datatype</groupId>
-            <artifactId>jackson-datatype-jdk8</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-
-        <dependency>
-            <groupId>com.fasterxml.jackson.datatype</groupId>
-            <artifactId>jackson-datatype-jsr310</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-java_2.11</artifactId>
-            <version>${play.version}</version>
+            <version>${playframework.version}</version>
             <exclusions>
                 <exclusion>
                     <groupId>com.google.code.findbugs</groupId>
@@ -161,25 +125,31 @@
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-json_2.11</artifactId>
-            <version>${play.version}</version>
+            <version>${playframework.version}</version>
         </dependency>
 
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-server_2.11</artifactId>
-            <version>${play.version}</version>
+            <version>${playframework.version}</version>
         </dependency>
 
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play_2.11</artifactId>
-            <version>${play.version}</version>
+            <version>${playframework.version}</version>
         </dependency>
 
         <dependency>
             <groupId>com.typesafe.play</groupId>
             <artifactId>play-netty-server_2.11</artifactId>
-            <version>${play.version}</version>
+            <version>${playframework.version}</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.typesafe.akka</groupId>
+            <artifactId>akka-cluster_2.11</artifactId>
+            <version>2.5.23</version>
         </dependency>
 
         <dependency>
@@ -194,6 +164,12 @@
             <artifactId>jcommander</artifactId>
             <version>${jcommander.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.11</artifactId>
+            <version>${spark.version}</version>
+        </dependency>
+
     </dependencies>
 
     <profiles>
@@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils;
 import org.datavec.api.transform.TransformProcess;
 import org.datavec.image.transform.ImageTransformProcess;
 import org.datavec.spark.transform.model.*;
+import play.BuiltInComponents;
 import play.Mode;
+import play.routing.Router;
 import play.routing.RoutingDsl;
 import play.server.Server;
 
 import java.io.File;
 import java.io.IOException;
+import java.util.Base64;
+import java.util.Random;
 
 import static play.mvc.Results.*;
 
@@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer {
             System.exit(1);
         }
 
-        RoutingDsl routingDsl = new RoutingDsl();
-
-
         if (jsonPath != null) {
             String json = FileUtils.readFileToString(new File(jsonPath));
             TransformProcess transformProcess = TransformProcess.fromJson(json);
@@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer {
                             + "to /transformprocess");
         }
 
-        routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+        //Set play secret key, if required
+        //http://www.playframework.com/documentation/latest/ApplicationSecret
+        String crypto = System.getProperty("play.crypto.secret");
+        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
+            byte[] newCrypto = new byte[1024];
+            new Random().nextBytes(newCrypto);
+            String base64 = Base64.getEncoder().encodeToString(newCrypto);
+            System.setProperty("play.crypto.secret", base64);
+        }
+
+        server = Server.forRouter(Mode.PROD, port, this::createRouter);
+    }
+
+    protected Router createRouter(BuiltInComponents b){
+        RoutingDsl routingDsl = RoutingDsl.fromComponents(b);
+
+        routingDsl.GET("/transformprocess").routingTo(req -> {
             try {
                 if (transform == null)
                     return badRequest();
@@ -88,11 +107,11 @@
                 log.error("Error in GET /transformprocess",e);
                 return internalServerError(e.getMessage());
             }
-        })));
+        });
 
-        routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformprocess").routingTo(req -> {
             try {
-                TransformProcess transformProcess = TransformProcess.fromJson(getJsonText());
+                TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
                 setCSVTransformProcess(transformProcess);
                 log.info("Transform process initialized");
                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@@ -100,12 +119,12 @@
                 log.error("Error in POST /transformprocess",e);
                 return internalServerError(e.getMessage());
             }
-        })));
+        });
 
-        routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> {
-            if (isSequence()) {
+        routingDsl.POST("/transformincremental").routingTo(req -> {
+            if (isSequence(req)) {
                 try {
-                    BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                     if (record == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
@@ -115,7 +134,7 @@
                 }
             } else {
                 try {
-                    SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
+                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                     if (record == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
@@ -124,12 +143,12 @@
                     return internalServerError(e.getMessage());
                 }
             }
-        })));
+        });
 
-        routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> {
-            if (isSequence()) {
+        routingDsl.POST("/transform").routingTo(req -> {
+            if (isSequence(req)) {
                 try {
-                    SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class));
+                    SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
                     if (batch == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(batch)).as(contentType);
@@ -139,7 +158,7 @@
                 }
             } else {
                 try {
-                    BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+                    BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                     BatchCSVRecord batch = transform(input);
                     if (batch == null)
                         return badRequest();
@@ -149,14 +168,12 @@
                     return internalServerError(e.getMessage());
                 }
             }
-        })));
-
-        routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
-            if (isSequence()) {
+        });
+
+        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
+            if (isSequence(req)) {
                 try {
-                    BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                     if (record == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
@@ -166,7 +183,7 @@
                 }
             } else {
                 try {
-                    SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
+                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                     if (record == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
@@ -175,13 +192,12 @@
                     return internalServerError(e.getMessage());
                 }
             }
-        })));
-
-        routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
-            if (isSequence()) {
+        });
+
+        routingDsl.POST("/transformarray").routingTo(req -> {
+            if (isSequence(req)) {
                 try {
-                    SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class);
+                    SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
                     if (batchCSVRecord == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
@@ -191,7 +207,7 @@
                 }
             } else {
                 try {
-                    BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
+                    BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                     if (batchCSVRecord == null)
                         return badRequest();
                     return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
@@ -200,10 +216,9 @@
                     return internalServerError(e.getMessage());
                 }
             }
-        })));
+        });
 
-        server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
+        return routingDsl.build();
     }
 
     public static void main(String[] args) throws Exception {
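The routing changes above follow the Play 2.7 Java pattern: the RoutingDsl is built from BuiltInComponents and a router factory is handed to Server.forRouter. A minimal standalone sketch of that pattern (class name, route, and port are illustrative only, not part of the diff):

import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import static play.mvc.Results.ok;

public class MinimalRoutingExample {
    public static void main(String[] args) {
        // The router is created lazily from the server's own components
        Server server = Server.forRouter(Mode.PROD, 9090, MinimalRoutingExample::createRouter);
        // server.stop() when done
    }

    static Router createRouter(BuiltInComponents components) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(components);
        // Handlers receive the Http.Request explicitly instead of using static request()
        routingDsl.GET("/ping").routingTo(req -> ok("pong"));
        return routingDsl.build();
    }
}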
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import play.libs.F;
-import play.mvc.Result;
-
-import java.util.function.Function;
-import java.util.function.Supplier;
-
-/**
- * Utility methods for Routing
- *
- * @author Alex Black
- */
-public class FunctionUtil {
-
-
-    public static F.Function0<Result> function0(Supplier<Result> supplier) {
-        return supplier::get;
-    }
-
-    public static <T> F.Function<T, Result> function(Function<T, Result> function) {
-        return function::apply;
-    }
-
-}
@@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils;
 import org.datavec.api.transform.TransformProcess;
 import org.datavec.image.transform.ImageTransformProcess;
 import org.datavec.spark.transform.model.*;
+import play.BuiltInComponents;
 import play.Mode;
+import play.libs.Files;
 import play.mvc.Http;
+import play.routing.Router;
 import play.routing.RoutingDsl;
 import play.server.Server;
 
@@ -33,6 +36,7 @@ import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.function.Function;
 
 import static play.mvc.Controller.request;
 import static play.mvc.Results.*;
@@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer {
             System.exit(1);
         }
 
-        RoutingDsl routingDsl = new RoutingDsl();
-
         if (jsonPath != null) {
             String json = FileUtils.readFileToString(new File(jsonPath));
             ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
@@ -73,7 +75,13 @@
                             + "to /transformprocess");
         }
 
-        routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+        server = Server.forRouter(Mode.PROD, port, this::createRouter);
+    }
+
+    protected Router createRouter(BuiltInComponents builtInComponents){
+        RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
+
+        routingDsl.GET("/transformprocess").routingTo(req -> {
             try {
                 if (transform == null)
                     return badRequest();
@@ -83,11 +91,11 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformprocess").routingTo(req -> {
             try {
-                ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText());
+                ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
                 setImageTransformProcess(transformProcess);
                 log.info("Transform process initialized");
                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@@ -95,11 +103,11 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
             try {
-                SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class);
+                SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
                 if (record == null)
                     return badRequest();
                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@@ -107,17 +115,17 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformincrementalimage").routingTo(req -> {
             try {
-                Http.MultipartFormData body = request().body().asMultipartFormData();
-                List<Http.MultipartFormData.FilePart> files = body.getFiles();
-                if (files.size() == 0 || files.get(0).getFile() == null) {
+                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
+                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
+                if (files.isEmpty() || files.get(0).getRef() == null ) {
                     return badRequest();
                 }
 
-                File file = files.get(0).getFile();
+                File file = files.get(0).getRef().path().toFile();
                 SingleImageRecord record = new SingleImageRecord(file.toURI());
 
                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@@ -125,11 +133,11 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformarray").routingTo(req -> {
             try {
-                BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class);
+                BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
                 if (batch == null)
                     return badRequest();
                 return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
@@ -137,22 +145,22 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/transformimage").routingTo(req -> {
             try {
-                Http.MultipartFormData body = request().body().asMultipartFormData();
-                List<Http.MultipartFormData.FilePart> files = body.getFiles();
+                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
+                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
                 if (files.size() == 0) {
                     return badRequest();
                 }
 
                 List<SingleImageRecord> records = new ArrayList<>();
 
-                for (Http.MultipartFormData.FilePart filePart : files) {
-                    File file = filePart.getFile();
+                for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
+                    Files.TemporaryFile file = filePart.getRef();
                     if (file != null) {
-                        SingleImageRecord record = new SingleImageRecord(file.toURI());
+                        SingleImageRecord record = new SingleImageRecord(file.path().toUri());
                         records.add(record);
                     }
                 }
@@ -164,9 +172,9 @@
                 e.printStackTrace();
                 return internalServerError();
             }
-        })));
+        });
 
-        server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
+        return routingDsl.build();
     }
 
     @Override
@@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
 import org.datavec.spark.transform.model.BatchCSVRecord;
 import org.datavec.spark.transform.service.DataVecTransformService;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
+import play.mvc.Http;
 import play.server.Server;
 
 import static play.mvc.Controller.request;
@@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService {
         server.stop();
     }
 
-    protected boolean isSequence() {
-        return request().hasHeader(SEQUENCE_OR_NOT_HEADER)
-                        && request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase()
-                        .equals("TRUE");
+    protected boolean isSequence(Http.Request request) {
+        return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
+                        && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
     }
 
-    protected String getHeaderValue(String value) {
-        if (request().hasHeader(value))
-            return request().getHeader(value);
-        return null;
-    }
-
-    protected String getJsonText() {
-        JsonNode tryJson = request().body().asJson();
+    protected String getJsonText(Http.Request request) {
+        JsonNode tryJson = request.body().asJson();
         if (tryJson != null)
             return tryJson.toString();
        else
-            return request().body().asText();
+            return request.body().asText();
     }
 
     public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
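The Http.Request is now threaded through explicitly instead of relying on the static Controller.request(). A hedged illustration of the Optional-based header access this enables (the helper and header-name parameter below are illustrative, not part of the diff):

import play.mvc.Http;

public class RequestHeaderExample {
    static boolean isTrueHeader(Http.Request request, String headerName) {
        // Optional-based header access replaces the old hasHeader()/getHeader() pair
        return request.header(headerName).map("true"::equalsIgnoreCase).orElse(false);
    }
}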
@@ -0,0 +1,350 @@
+# This is the main configuration file for the application.
+# https://www.playframework.com/documentation/latest/ConfigFile
+# ~~~~~
+# Play uses HOCON as its configuration file format. HOCON has a number
+# of advantages over other config formats, but there are two things that
+# can be used when modifying settings.
+#
+# You can include other configuration files in this main application.conf file:
+#include "extra-config.conf"
+#
+# You can declare variables and substitute for them:
+#mykey = ${some.value}
+#
+# And if an environment variable exists when there is no other subsitution, then
+# HOCON will fall back to substituting environment variable:
+#mykey = ${JAVA_HOME}
+
+## Akka
+# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
+# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
+# ~~~~~
+# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
+# other streaming HTTP responses.
+akka {
+  # "akka.log-config-on-start" is extraordinarly useful because it log the complete
+  # configuration at INFO level, including defaults and overrides, so it s worth
+  # putting at the very top.
+  #
+  # Put the following in your conf/logback.xml file:
+  #
+  # <logger name="akka.actor" level="INFO" />
+  #
+  # And then uncomment this line to debug the configuration.
+  #
+  #log-config-on-start = true
+}
+
+## Modules
+# https://www.playframework.com/documentation/latest/Modules
+# ~~~~~
+# Control which modules are loaded when Play starts. Note that modules are
+# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
+# Please see https://www.playframework.com/documentation/latest/GlobalSettings
+# for more information.
+#
+# You can also extend Play functionality by using one of the publically available
+# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
+play.modules {
+  # By default, Play will load any class called Module that is defined
+  # in the root package (the "app" directory), or you can define them
+  # explicitly below.
+  # If there are any built-in modules that you want to disable, you can list them here.
+  #enabled += my.application.Module
+
+  # If there are any built-in modules that you want to disable, you can list them here.
+  #disabled += ""
+}
+
+## Internationalisation
+# https://www.playframework.com/documentation/latest/JavaI18N
+# https://www.playframework.com/documentation/latest/ScalaI18N
+# ~~~~~
+# Play comes with its own i18n settings, which allow the user's preferred language
+# to map through to internal messages, or allow the language to be stored in a cookie.
+play.i18n {
+  # The application languages
+  langs = [ "en" ]
+
+  # Whether the language cookie should be secure or not
+  #langCookieSecure = true
+
+  # Whether the HTTP only attribute of the cookie should be set to true
+  #langCookieHttpOnly = true
+}
+
+## Play HTTP settings
+# ~~~~~
+play.http {
+  ## Router
+  # https://www.playframework.com/documentation/latest/JavaRouting
+  # https://www.playframework.com/documentation/latest/ScalaRouting
+  # ~~~~~
+  # Define the Router object to use for this application.
+  # This router will be looked up first when the application is starting up,
+  # so make sure this is the entry point.
+  # Furthermore, it's assumed your route file is named properly.
+  # So for an application router like `my.application.Router`,
+  # you may need to define a router file `conf/my.application.routes`.
+  # Default to Routes in the root package (aka "apps" folder) (and conf/routes)
+  #router = my.application.Router
+
+  ## Action Creator
+  # https://www.playframework.com/documentation/latest/JavaActionCreator
+  # ~~~~~
+  #actionCreator = null
+
+  ## ErrorHandler
+  # https://www.playframework.com/documentation/latest/JavaRouting
+  # https://www.playframework.com/documentation/latest/ScalaRouting
+  # ~~~~~
+  # If null, will attempt to load a class called ErrorHandler in the root package,
+  #errorHandler = null
+
+  ## Filters
+  # https://www.playframework.com/documentation/latest/ScalaHttpFilters
+  # https://www.playframework.com/documentation/latest/JavaHttpFilters
+  # ~~~~~
+  # Filters run code on every request. They can be used to perform
+  # common logic for all your actions, e.g. adding common headers.
+  # Defaults to "Filters" in the root package (aka "apps" folder)
+  # Alternatively you can explicitly register a class here.
+  #filters += my.application.Filters
+
+  ## Session & Flash
+  # https://www.playframework.com/documentation/latest/JavaSessionFlash
+  # https://www.playframework.com/documentation/latest/ScalaSessionFlash
+  # ~~~~~
+  session {
+    # Sets the cookie to be sent only over HTTPS.
+    #secure = true
+
+    # Sets the cookie to be accessed only by the server.
+    #httpOnly = true
+
+    # Sets the max-age field of the cookie to 5 minutes.
+    # NOTE: this only sets when the browser will discard the cookie. Play will consider any
+    # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
+    # you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
+    #maxAge = 300
+
+    # Sets the domain on the session cookie.
+    #domain = "example.com"
+  }
+
+  flash {
+    # Sets the cookie to be sent only over HTTPS.
+    #secure = true
+
+    # Sets the cookie to be accessed only by the server.
+    #httpOnly = true
+  }
+}
+
+## Netty Provider
+# https://www.playframework.com/documentation/latest/SettingsNetty
+# ~~~~~
+play.server.netty {
+  # Whether the Netty wire should be logged
+  #log.wire = true
+
+  # If you run Play on Linux, you can use Netty's native socket transport
+  # for higher performance with less garbage.
+  #transport = "native"
+}
+
+## WS (HTTP Client)
+# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
+# ~~~~~
+# The HTTP client primarily used for REST APIs. The default client can be
+# configured directly, but you can also create different client instances
+# with customized settings. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += ws // or javaWs if using java
+#
+play.ws {
+  # Sets HTTP requests not to follow 302 requests
+  #followRedirects = false
+
+  # Sets the maximum number of open HTTP connections for the client.
+  #ahc.maxConnectionsTotal = 50
+
+  ## WS SSL
+  # https://www.playframework.com/documentation/latest/WsSSL
+  # ~~~~~
+  ssl {
+    # Configuring HTTPS with Play WS does not require programming. You can
+    # set up both trustManager and keyManager for mutual authentication, and
+    # turn on JSSE debugging in development with a reload.
+    #debug.handshake = true
+    #trustManager = {
+    #  stores = [
+    #    { type = "JKS", path = "exampletrust.jks" }
+    #  ]
+    #}
+  }
+}
+
+## Cache
+# https://www.playframework.com/documentation/latest/JavaCache
+# https://www.playframework.com/documentation/latest/ScalaCache
+# ~~~~~
+# Play comes with an integrated cache API that can reduce the operational
+# overhead of repeated requests. You must enable this by adding to build.sbt:
+#
+# libraryDependencies += cache
+#
+play.cache {
+  # If you want to bind several caches, you can bind the individually
+  #bindCaches = ["db-cache", "user-cache", "session-cache"]
+}
+
+## Filters
+# https://www.playframework.com/documentation/latest/Filters
|
||||||
|
# ~~~~~
|
||||||
|
# There are a number of built-in filters that can be enabled and configured
|
||||||
|
# to give Play greater security. You must enable this by adding to build.sbt:
|
||||||
|
#
|
||||||
|
# libraryDependencies += filters
|
||||||
|
#
|
||||||
|
play.filters {
|
||||||
|
## CORS filter configuration
|
||||||
|
# https://www.playframework.com/documentation/latest/CorsFilter
|
||||||
|
# ~~~~~
|
||||||
|
# CORS is a protocol that allows web applications to make requests from the browser
|
||||||
|
# across different domains.
|
||||||
|
# NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
|
||||||
|
# dependencies on CORS settings.
|
||||||
|
cors {
|
||||||
|
# Filter paths by a whitelist of path prefixes
|
||||||
|
#pathPrefixes = ["/some/path", ...]
|
||||||
|
|
||||||
|
# The allowed origins. If null, all origins are allowed.
|
||||||
|
#allowedOrigins = ["http://www.example.com"]
|
||||||
|
|
||||||
|
# The allowed HTTP methods. If null, all methods are allowed
|
||||||
|
#allowedHttpMethods = ["GET", "POST"]
|
||||||
|
}
|
||||||
|
|
||||||
|
## CSRF Filter
|
||||||
|
# https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
|
||||||
|
# https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
|
||||||
|
# ~~~~~
|
||||||
|
# Play supports multiple methods for verifying that a request is not a CSRF request.
|
||||||
|
# The primary mechanism is a CSRF token. This token gets placed either in the query string
|
||||||
|
# or body of every form submitted, and also gets placed in the users session.
|
||||||
|
# Play then verifies that both tokens are present and match.
|
||||||
|
csrf {
|
||||||
|
# Sets the cookie to be sent only over HTTPS
|
||||||
|
#cookie.secure = true
|
||||||
|
|
||||||
|
# Defaults to CSRFErrorHandler in the root package.
|
||||||
|
#errorHandler = MyCSRFErrorHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
## Security headers filter configuration
|
||||||
|
# https://www.playframework.com/documentation/latest/SecurityHeaders
|
||||||
|
# ~~~~~
|
||||||
|
# Defines security headers that prevent XSS attacks.
|
||||||
|
# If enabled, then all options are set to the below configuration by default:
|
||||||
|
headers {
|
||||||
|
# The X-Frame-Options header. If null, the header is not set.
|
||||||
|
#frameOptions = "DENY"
|
||||||
|
|
||||||
|
# The X-XSS-Protection header. If null, the header is not set.
|
||||||
|
#xssProtection = "1; mode=block"
|
||||||
|
|
||||||
|
# The X-Content-Type-Options header. If null, the header is not set.
|
||||||
|
#contentTypeOptions = "nosniff"
|
||||||
|
|
||||||
|
# The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
|
||||||
|
#permittedCrossDomainPolicies = "master-only"
|
||||||
|
|
||||||
|
# The Content-Security-Policy header. If null, the header is not set.
|
||||||
|
#contentSecurityPolicy = "default-src 'self'"
|
||||||
|
}
|
||||||
|
|
||||||
|
## Allowed hosts filter configuration
|
||||||
|
# https://www.playframework.com/documentation/latest/AllowedHostsFilter
|
||||||
|
# ~~~~~
|
||||||
|
# Play provides a filter that lets you configure which hosts can access your application.
|
||||||
|
# This is useful to prevent cache poisoning attacks.
|
||||||
|
hosts {
|
||||||
|
# Allow requests to example.com, its subdomains, and localhost:9000.
|
||||||
|
#allowed = [".example.com", "localhost:9000"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
## Evolutions
|
||||||
|
# https://www.playframework.com/documentation/latest/Evolutions
|
||||||
|
# ~~~~~
|
||||||
|
# Evolutions allows database scripts to be automatically run on startup in dev mode
|
||||||
|
# for database migrations. You must enable this by adding to build.sbt:
|
||||||
|
#
|
||||||
|
# libraryDependencies += evolutions
|
||||||
|
#
|
||||||
|
play.evolutions {
|
||||||
|
# You can disable evolutions for a specific datasource if necessary
|
||||||
|
#db.default.enabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
## Database Connection Pool
|
||||||
|
# https://www.playframework.com/documentation/latest/SettingsJDBC
|
||||||
|
# ~~~~~
|
||||||
|
# Play doesn't require a JDBC database to run, but you can easily enable one.
|
||||||
|
#
|
||||||
|
# libraryDependencies += jdbc
|
||||||
|
#
|
||||||
|
play.db {
|
||||||
|
# The combination of these two settings results in "db.default" as the
|
||||||
|
# default JDBC pool:
|
||||||
|
#config = "db"
|
||||||
|
#default = "default"
|
||||||
|
|
||||||
|
# Play uses HikariCP as the default connection pool. You can override
|
||||||
|
# settings by changing the prototype:
|
||||||
|
prototype {
|
||||||
|
# Sets a fixed JDBC connection pool size of 50
|
||||||
|
#hikaricp.minimumIdle = 50
|
||||||
|
#hikaricp.maximumPoolSize = 50
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
## JDBC Datasource
|
||||||
|
# https://www.playframework.com/documentation/latest/JavaDatabase
|
||||||
|
# https://www.playframework.com/documentation/latest/ScalaDatabase
|
||||||
|
# ~~~~~
|
||||||
|
# Once JDBC datasource is set up, you can work with several different
|
||||||
|
# database options:
|
||||||
|
#
|
||||||
|
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
|
||||||
|
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
|
||||||
|
# EBean: https://playframework.com/documentation/latest/JavaEbean
|
||||||
|
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
|
||||||
|
#
|
||||||
|
db {
|
||||||
|
# You can declare as many datasources as you want.
|
||||||
|
# By convention, the default datasource is named `default`
|
||||||
|
|
||||||
|
# https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
|
||||||
|
default.driver = org.h2.Driver
|
||||||
|
default.url = "jdbc:h2:mem:play"
|
||||||
|
#default.username = sa
|
||||||
|
#default.password = ""
|
||||||
|
|
||||||
|
# You can expose this datasource via JNDI if needed (Useful for JPA)
|
||||||
|
default.jndiName=DefaultDS
|
||||||
|
|
||||||
|
# You can turn on SQL logging for any datasource
|
||||||
|
# https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
|
||||||
|
#default.logSql=true
|
||||||
|
}
|
||||||
|
|
||||||
|
jpa.default=defaultPersistenceUnit
|
||||||
|
|
||||||
|
|
||||||
|
#Increase default maximum post length - used for remote listener functionality
|
||||||
|
#Can get response 413 with larger networks without setting this
|
||||||
|
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
|
||||||
|
#parsers.text.maxLength=10M
|
||||||
|
play.http.parser.maxMemoryBuffer=10M
|
|
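As an aside, a minimal sketch of how one could double-check the effective value of the `play.http.parser.maxMemoryBuffer` override at runtime via the Typesafe Config API that Play builds on. The key name comes from the file above; the class name and the printed expectation are illustrative assumptions, not part of this PR.

import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

// Hypothetical sanity check: confirms the larger parser buffer from application.conf
// is actually visible on the classpath the UI / remote-listener server starts with.
public class ParserBufferCheck {
    public static void main(String[] args) {
        Config conf = ConfigFactory.load(); // loads application.conf from the classpath
        String maxBuffer = conf.getString("play.http.parser.maxMemoryBuffer");
        // A value smaller than 10M would be consistent with the HTTP 413 responses
        // mentioned in the comment for larger networks.
        System.out.println("play.http.parser.maxMemoryBuffer = " + maxBuffer);
    }
}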
@@ -28,61 +28,11 @@
     <artifactId>datavec-spark_2.11</artifactId>

     <properties>
-        <!-- These Spark versions have to be here, and NOT in the datavec parent pom. Otherwise, the properties
-        will be resolved to their defaults. Whereas when defined here (the version outputStream spark 1 vs. 2 specific) we can
-        have different version properties simultaneously (in different pom files) instead of a single global property -->
-        <spark.version>2.1.0</spark.version>
-        <spark.major.version>2</spark.major.version>
-
         <!-- Default scala versions, may be overwritten by build profiles -->
         <scala.version>2.11.12</scala.version>
         <scala.binary.version>2.11</scala.binary.version>
     </properties>

-    <build>
-        <plugins>
-            <!-- added source folder containing the code specific to the spark version -->
-            <plugin>
-                <groupId>org.codehaus.mojo</groupId>
-                <artifactId>build-helper-maven-plugin</artifactId>
-                <executions>
-                    <execution>
-                        <id>add-source</id>
-                        <phase>generate-sources</phase>
-                        <goals>
-                            <goal>add-source</goal>
-                        </goals>
-                        <configuration>
-                            <sources>
-                                <source>src/main/spark-${spark.major.version}</source>
-                            </sources>
-                        </configuration>
-                    </execution>
-                </executions>
-            </plugin>
-        </plugins>
-    </build>
-
-    <dependencyManagement>
-        <dependencies>
-            <dependency>
-                <groupId>com.fasterxml.jackson.datatype</groupId>
-                <artifactId>jackson-datatype-jsr310</artifactId>
-                <version>${jackson.version}</version>
-            </dependency>
-            <dependency>
-                <groupId>com.fasterxml.jackson.dataformat</groupId>
-                <artifactId>jackson-dataformat-yaml</artifactId>
-                <version>${jackson.version}</version>
-            </dependency>
-            <dependency>
-                <groupId>com.fasterxml.jackson.module</groupId>
-                <artifactId>jackson-module-scala_2.11</artifactId>
-                <version>${jackson.version}</version>
-            </dependency>
-        </dependencies>
-    </dependencyManagement>
-
     <dependencies>
         <dependency>
             <groupId>org.scala-lang</groupId>

@@ -95,42 +45,13 @@
             <version>${scala.version}</version>
         </dependency>

-        <dependency>
-            <groupId>org.codehaus.jackson</groupId>
-            <artifactId>jackson-core-asl</artifactId>
-            <version>${jackson-asl.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.codehaus.jackson</groupId>
-            <artifactId>jackson-mapper-asl</artifactId>
-            <version>${jackson-asl.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-sql_2.11</artifactId>
             <version>${spark.version}</version>
+            <scope>provided</scope>
         </dependency>

-        <dependency>
-            <groupId>com.google.guava</groupId>
-            <artifactId>guava</artifactId>
-            <version>${guava.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.google.inject</groupId>
-            <artifactId>guice</artifactId>
-            <version>${guice.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.google.protobuf</groupId>
-            <artifactId>protobuf-java</artifactId>
-            <version>${google.protobuf.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>commons-codec</groupId>
-            <artifactId>commons-codec</artifactId>
-            <version>${commons-codec.version}</version>
-        </dependency>
         <dependency>
             <groupId>commons-collections</groupId>
             <artifactId>commons-collections</artifactId>

@@ -141,96 +62,16 @@
             <artifactId>commons-io</artifactId>
             <version>${commons-io.version}</version>
         </dependency>
-        <dependency>
-            <groupId>commons-lang</groupId>
-            <artifactId>commons-lang</artifactId>
-            <version>${commons-lang.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>commons-net</groupId>
-            <artifactId>commons-net</artifactId>
-            <version>${commons-net.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.sun.xml.bind</groupId>
-            <artifactId>jaxb-core</artifactId>
-            <version>${jaxb.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.sun.xml.bind</groupId>
-            <artifactId>jaxb-impl</artifactId>
-            <version>${jaxb.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-actor_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-remote_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-slf4j_2.11</artifactId>
-            <version>${akka.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>io.netty</groupId>
-            <artifactId>netty</artifactId>
-            <version>${netty.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-core</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-databind</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.fasterxml.jackson.core</groupId>
-            <artifactId>jackson-annotations</artifactId>
-            <version>${jackson.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>javax.servlet</groupId>
-            <artifactId>javax.servlet-api</artifactId>
-            <version>${servlet.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.apache.commons</groupId>
-            <artifactId>commons-compress</artifactId>
-            <version>${commons-compress.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.apache.commons</groupId>
-            <artifactId>commons-lang3</artifactId>
-            <version>${commons-lang3.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-math3</artifactId>
             <version>${commons-math3.version}</version>
         </dependency>
-        <dependency>
-            <groupId>org.apache.curator</groupId>
-            <artifactId>curator-recipes</artifactId>
-            <version>${curator.version}</version>
-        </dependency>
         <dependency>
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
             <version>${slf4j.version}</version>
         </dependency>
-        <dependency>
-            <groupId>com.typesafe</groupId>
-            <artifactId>config</artifactId>
-            <version>${typesafe.config.version}</version>
-        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>

@@ -241,14 +82,6 @@
                <groupId>com.google.code.findbugs</groupId>
                <artifactId>jsr305</artifactId>
            </exclusion>
-            <exclusion>
-                <groupId>org.slf4j</groupId>
-                <artifactId>slf4j-log4j12</artifactId>
-            </exclusion>
-            <exclusion>
-                <groupId>log4j</groupId>
-                <artifactId>log4j</artifactId>
-            </exclusion>
            </exclusions>
        </dependency>

@@ -1,29 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.functions;
-
-import java.io.Serializable;
-
-/**
- *
- * A function that returns zero or more output records from each input record.
- *
- * Adapter for Spark interface in order to freeze interface changes between spark versions
- */
-public interface FlatMapFunctionAdapter<T, R> extends Serializable {
-    Iterable<R> call(T t) throws Exception;
-}
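For context, a minimal sketch of the pattern that replaces the deleted adapter: with Spark 1 support gone, a function can implement org.apache.spark.api.java.function.FlatMapFunction directly and return an Iterator, which is what the Spark 2 contract expects. The class name and tokenizing logic below are illustrative assumptions, not code from this PR.

import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Illustrative only: splits a line into tokens, returning an Iterator as required by
// the Spark 2 FlatMapFunction contract. Spark 1 expected an Iterable here, which is the
// difference the removed FlatMapFunctionAdapter used to paper over.
public class TokenizeFunction implements FlatMapFunction<String, String> {
    @Override
    public Iterator<String> call(String line) throws Exception {
        List<String> tokens = Arrays.asList(line.split("\\s+"));
        return tokens.iterator();
    }
}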
@@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.sql.Column;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.functions;
+import org.apache.spark.sql.*;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;

@@ -46,7 +43,6 @@ import java.util.List;

 import static org.apache.spark.sql.functions.avg;
 import static org.apache.spark.sql.functions.col;
-import static org.datavec.spark.transform.DataRowsFacade.dataRows;


 /**

@@ -71,7 +67,7 @@ public class DataFrames {
      * deviation for
      * @return the column that represents the standard deviation
      */
-    public static Column std(DataRowsFacade dataFrame, String columnName) {
+    public static Column std(Dataset<Row> dataFrame, String columnName) {
         return functions.sqrt(var(dataFrame, columnName));
     }

@@ -85,8 +81,8 @@ public class DataFrames {
      * deviation for
      * @return the column that represents the standard deviation
      */
-    public static Column var(DataRowsFacade dataFrame, String columnName) {
-        return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
+    public static Column var(Dataset<Row> dataFrame, String columnName) {
+        return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
     }

     /**

@@ -97,8 +93,8 @@ public class DataFrames {
      * @param columnName the name of the column to get the min for
      * @return the column that represents the min
      */
-    public static Column min(DataRowsFacade dataFrame, String columnName) {
-        return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName);
+    public static Column min(Dataset<Row> dataFrame, String columnName) {
+        return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName);
     }

     /**

@@ -110,8 +106,8 @@ public class DataFrames {
      * to get the max for
      * @return the column that represents the max
      */
-    public static Column max(DataRowsFacade dataFrame, String columnName) {
-        return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName);
+    public static Column max(Dataset<Row> dataFrame, String columnName) {
+        return dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName);
     }

     /**

@@ -122,8 +118,8 @@ public class DataFrames {
      * @param columnName the name of the column to get the mean for
      * @return the column that represents the mean
      */
-    public static Column mean(DataRowsFacade dataFrame, String columnName) {
-        return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName);
+    public static Column mean(Dataset<Row> dataFrame, String columnName) {
+        return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName);
     }

     /**

@@ -166,7 +162,7 @@ public class DataFrames {
      * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
      * of this record in the original time series.<br>
      * These two columns are required if the data is to be converted back into a sequence at a later point, for example
-     * using {@link #toRecordsSequence(DataRowsFacade)}
+     * using {@link #toRecordsSequence(Dataset<Row>)}
      *
      * @param schema Schema to convert
      * @return StructType for the schema

@@ -250,9 +246,9 @@ public class DataFrames {
      * @param dataFrame the dataframe to convert
      * @return the converted schema and rdd of writables
      */
-    public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(DataRowsFacade dataFrame) {
-        Schema schema = fromStructType(dataFrame.get().schema());
-        return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema)));
+    public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(Dataset<Row> dataFrame) {
+        Schema schema = fromStructType(dataFrame.schema());
+        return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema)));
     }

     /**

@@ -267,11 +263,11 @@ public class DataFrames {
      * @param dataFrame Data frame to convert
      * @return Data in sequence (i.e., {@code List<List<Writable>>} form
      */
-    public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(DataRowsFacade dataFrame) {
+    public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(Dataset<Row> dataFrame) {

         //Need to convert from flattened to sequence data...
         //First: Group by the Sequence UUID (first column)
-        JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.get().javaRDD().groupBy(new Function<Row, String>() {
+        JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.javaRDD().groupBy(new Function<Row, String>() {
             @Override
             public String call(Row row) throws Exception {
                 return row.getString(0);

@@ -279,7 +275,7 @@ public class DataFrames {
         });


-        Schema schema = fromStructType(dataFrame.get().schema());
+        Schema schema = fromStructType(dataFrame.schema());

         //Group by sequence UUID, and sort each row within the sequences using the time step index
         Function<Iterable<Row>, List<List<Writable>>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner

@@ -318,11 +314,11 @@ public class DataFrames {
      * @param data the data to convert
      * @return the dataframe object
      */
-    public static DataRowsFacade toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
+    public static Dataset<Row> toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
         JavaSparkContext sc = new JavaSparkContext(data.context());
         SQLContext sqlContext = new SQLContext(sc);
         JavaRDD<Row> rows = data.map(new ToRow(schema));
-        return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema)));
+        return sqlContext.createDataFrame(rows, fromSchema(schema));
     }


@@ -333,18 +329,18 @@ public class DataFrames {
      * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
      * of this record in the original time series.<br>
      * These two columns are required if the data is to be converted back into a sequence at a later point, for example
-     * using {@link #toRecordsSequence(DataRowsFacade)}
+     * using {@link #toRecordsSequence(Dataset<Row>)}
      *
      * @param schema Schema for the data
      * @param data Sequence data to convert to a DataFrame
      * @return The dataframe object
      */
-    public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
+    public static Dataset<Row> toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
         JavaSparkContext sc = new JavaSparkContext(data.context());

         SQLContext sqlContext = new SQLContext(sc);
         JavaRDD<Row> rows = data.flatMap(new SequenceToRows(schema));
-        return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema)));
+        return sqlContext.createDataFrame(rows, fromSchemaSequence(schema));
     }

     /**
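A short usage sketch of the updated DataFrames API, based only on the signatures shown in this hunk (toDataFrame now returns a plain Dataset<Row> and toRecords accepts one, with no DataRowsFacade wrapper). The schema, column name, local-master setup and class name are illustrative assumptions.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class DataFramesRoundTrip {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setMaster("local[*]").setAppName("datavec-dataframe-sketch"));

        // Illustrative single-column schema and a tiny RDD of records
        Schema schema = new Schema.Builder().addColumnDouble("value").build();
        JavaRDD<List<Writable>> records = sc.parallelize(Arrays.asList(
                Collections.<Writable>singletonList(new DoubleWritable(1.0)),
                Collections.<Writable>singletonList(new DoubleWritable(2.0))));

        // Records -> Dataset<Row>, no facade wrapper any more
        Dataset<Row> df = DataFrames.toDataFrame(schema, records);
        df.show();

        // ...and back again; getSecond() is the JavaRDD<List<Writable>> side of the pair
        JavaRDD<List<Writable>> roundTripped = DataFrames.toRecords(df).getSecond();
        System.out.println(roundTripped.count());

        sc.stop();
    }
}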
@@ -19,14 +19,13 @@ package org.datavec.spark.transform;
 import org.apache.commons.collections.map.ListOrderedMap;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.writable.Writable;

 import java.util.*;

-import static org.datavec.spark.transform.DataRowsFacade.dataRows;


 /**
  * Simple dataframe based normalization.

@@ -46,7 +45,7 @@ public class Normalization {
      * @return a zero mean unit variance centered
      * rdd
      */
-    public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) {
+    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame) {
         return zeromeanUnitVariance(frame, Collections.<String>emptyList());
     }

@@ -71,7 +70,7 @@ public class Normalization {
      * @param max the maximum value
      * @return the normalized dataframe per column
      */
-    public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) {
+    public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max) {
         return normalize(dataFrame, min, max, Collections.<String>emptyList());
     }

@@ -86,7 +85,7 @@ public class Normalization {
      */
     public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min,
                     double max) {
-        DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
         return DataFrames.toRecords(normalize(frame, min, max, Collections.<String>emptyList())).getSecond();
     }

@@ -97,7 +96,7 @@ public class Normalization {
      * @param dataFrame the dataframe to scale
      * @return the normalized dataframe per column
      */
-    public static DataRowsFacade normalize(DataRowsFacade dataFrame) {
+    public static Dataset<Row> normalize(Dataset<Row> dataFrame) {
         return normalize(dataFrame, 0, 1, Collections.<String>emptyList());
     }

@@ -120,8 +119,8 @@ public class Normalization {
      * @return a zero mean unit variance centered
      * rdd
      */
-    public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List<String> skipColumns) {
-        List<String> columnsList = DataFrames.toList(frame.get().columns());
+    public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame, List<String> skipColumns) {
+        List<String> columnsList = DataFrames.toList(frame.columns());
         columnsList.removeAll(skipColumns);
         String[] columnNames = DataFrames.toArray(columnsList);
         //first row is std second row is mean, each column in a row is for a particular column

@@ -133,7 +132,7 @@ public class Normalization {
             if (std == 0.0)
                 std = 1; //All same value -> (x-x)/1 = 0

-            frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std)));
+            frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std));
         }


@@ -152,7 +151,7 @@ public class Normalization {
      */
     public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data,
                     List<String> skipColumns) {
-        DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
         return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond();
     }

@@ -178,7 +177,7 @@ public class Normalization {
      */
     public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema,
                     JavaRDD<List<List<Writable>>> sequence, List<String> excludeColumns) {
-        DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence);
+        Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, sequence);
         if (excludeColumns == null)
             excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN);
         else {

@@ -196,7 +195,7 @@ public class Normalization {
      * @param columns the columns to get the
      * @return
      */
-    public static List<Row> minMaxColumns(DataRowsFacade data, List<String> columns) {
+    public static List<Row> minMaxColumns(Dataset<Row> data, List<String> columns) {
         String[] arr = new String[columns.size()];
         for (int i = 0; i < arr.length; i++)
             arr[i] = columns.get(i);

@@ -210,7 +209,7 @@ public class Normalization {
      * @param columns the columns to get the
      * @return
      */
-    public static List<Row> minMaxColumns(DataRowsFacade data, String... columns) {
+    public static List<Row> minMaxColumns(Dataset<Row> data, String... columns) {
         return aggregate(data, columns, new String[] {"min", "max"});
     }

@@ -221,7 +220,7 @@ public class Normalization {
      * @param columns the columns to get the
      * @return
      */
-    public static List<Row> stdDevMeanColumns(DataRowsFacade data, List<String> columns) {
+    public static List<Row> stdDevMeanColumns(Dataset<Row> data, List<String> columns) {
         String[] arr = new String[columns.size()];
         for (int i = 0; i < arr.length; i++)
             arr[i] = columns.get(i);

@@ -237,7 +236,7 @@ public class Normalization {
      * @param columns the columns to get the
      * @return
      */
-    public static List<Row> stdDevMeanColumns(DataRowsFacade data, String... columns) {
+    public static List<Row> stdDevMeanColumns(Dataset<Row> data, String... columns) {
         return aggregate(data, columns, new String[] {"stddev", "mean"});
     }

@@ -251,7 +250,7 @@ public class Normalization {
      * Each row will be a function with the desired columnar output
      * in the order in which the columns were specified.
      */
-    public static List<Row> aggregate(DataRowsFacade data, String[] columns, String[] functions) {
+    public static List<Row> aggregate(Dataset<Row> data, String[] columns, String[] functions) {
         String[] rest = new String[columns.length - 1];
         System.arraycopy(columns, 1, rest, 0, rest.length);
         List<Row> rows = new ArrayList<>();

@@ -262,8 +261,8 @@ public class Normalization {
         }

         //compute the aggregation based on the operation
-        DataRowsFacade aggregated = dataRows(data.get().agg(expressions));
-        String[] columns2 = aggregated.get().columns();
+        Dataset<Row> aggregated = data.agg(expressions);
+        String[] columns2 = aggregated.columns();
         //strip out the op name and parentheses from the columns
         Map<String, String> opReplace = new TreeMap<>();
         for (String s : columns2) {

@@ -278,20 +277,20 @@ public class Normalization {


         //get rid of the operation name in the column
-        DataRowsFacade rearranged = null;
+        Dataset<Row> rearranged = null;
         for (Map.Entry<String, String> entries : opReplace.entrySet()) {
             //first column
             if (rearranged == null) {
-                rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue()));
+                rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue());
             }
             //rearranged is just a copy of aggregated at this point
             else
-                rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue()));
+                rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue());
         }

-        rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns)));
+        rearranged = rearranged.select(DataFrames.toColumns(columns));
         //op
-        rows.addAll(rearranged.get().collectAsList());
+        rows.addAll(rearranged.collectAsList());
         }


@@ -307,8 +306,8 @@ public class Normalization {
      * @param max the maximum value
      * @return the normalized dataframe per column
      */
-    public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List<String> skipColumns) {
-        List<String> columnsList = DataFrames.toList(dataFrame.get().columns());
+    public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max, List<String> skipColumns) {
+        List<String> columnsList = DataFrames.toList(dataFrame.columns());
         columnsList.removeAll(skipColumns);
         String[] columnNames = DataFrames.toArray(columnsList);
         //first row is min second row is max, each column in a row is for a particular column

@@ -321,8 +320,8 @@ public class Normalization {
             if (maxSubMin == 0)
                 maxSubMin = 1;

-            Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
-            dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol));
+            Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
+            dataFrame = dataFrame.withColumn(columnName, newCol);
         }


@@ -340,7 +339,7 @@ public class Normalization {
      */
     public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max,
                     List<String> skipColumns) {
-        DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
+        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
         return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond();
     }

@@ -387,7 +386,7 @@ public class Normalization {
             excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN);
             excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN);
         }
-        DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data);
+        Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, data);
         return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond();
     }

@@ -398,7 +397,7 @@ public class Normalization {
      * @param dataFrame the dataframe to scale
      * @return the normalized dataframe per column
      */
-    public static DataRowsFacade normalize(DataRowsFacade dataFrame, List<String> skipColumns) {
+    public static Dataset<Row> normalize(Dataset<Row> dataFrame, List<String> skipColumns) {
         return normalize(dataFrame, 0, 1, skipColumns);
     }

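A hedged usage sketch of the reworked Normalization API, using only the Dataset<Row>-based signatures visible in the hunks above; the wrapper class, method name and the decision to show the standardized frame are illustrative assumptions.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;
import org.datavec.spark.transform.Normalization;

import java.util.List;

public class NormalizationSketch {
    // Scales every column of the records into [0, 1] and also produces a
    // zero-mean / unit-variance copy, using the Dataset<Row>-based methods above.
    public static JavaRDD<List<Writable>> scaleToUnitRange(Schema schema, JavaRDD<List<Writable>> data) {
        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
        Dataset<Row> scaled = Normalization.normalize(frame, 0, 1);
        Dataset<Row> standardized = Normalization.zeromeanUnitVariance(frame);
        standardized.show(); // illustrative side output only
        return DataFrames.toRecords(scaled).getSecond();
    }
}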
@@ -16,9 +16,10 @@

 package org.datavec.spark.transform.analysis;

+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;

+import java.util.Iterator;
 import java.util.List;

 /**

@@ -27,10 +28,11 @@ import java.util.List;
  *
  * @author Alex Black
  */
-public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<Writable>> {
+public class SequenceFlatMapFunction implements FlatMapFunction<List<List<Writable>>, List<Writable>> {

-    public SequenceFlatMapFunction() {
-        super(new SequenceFlatMapFunctionAdapter());
+    @Override
+    public Iterator<List<Writable>> call(List<List<Writable>> collections) throws Exception {
+        return collections.iterator();
     }

 }
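A minimal usage sketch of the rewritten function: flattening an RDD of sequences into an RDD of individual time steps via a plain flatMap call. Only SequenceFlatMapFunction and its types come from the diff; the wrapper class and method name are assumptions for illustration.

import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;

import java.util.List;

public class FlattenSequences {
    // Flattens an RDD of sequences (one List<List<Writable>> per sequence) into an
    // RDD of individual time steps, using the FlatMapFunction-based implementation above.
    public static JavaRDD<List<Writable>> flatten(JavaRDD<List<List<Writable>>> sequences) {
        return sequences.flatMap(new SequenceFlatMapFunction());
    }
}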
@@ -1,36 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.analysis;
-
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.List;
-
-/**
- * SequenceFlatMapFunction: very simple function used to flatten a sequence
- * Typically used only internally for certain analysis operations
- *
- * @author Alex Black
- */
-public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, List<Writable>> {
-    @Override
-    public Iterable<List<Writable>> call(List<List<Writable>> collections) throws Exception {
-        return collections;
-    }
-
-}
@@ -16,11 +16,14 @@

 package org.datavec.spark.transform.join;

+import org.nd4j.shade.guava.collect.Iterables;
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.datavec.api.transform.join.Join;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
 import scala.Tuple2;

+import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;

 /**

@@ -28,10 +31,89 @@ import java.util.List;
  *
  * @author Alex Black
  */
-public class ExecuteJoinFromCoGroupFlatMapFunction extends
-                BaseFlatMapFunctionAdaptee<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
+public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
+
+    private final Join join;

     public ExecuteJoinFromCoGroupFlatMapFunction(Join join) {
-        super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join));
+        this.join = join;
+    }
+
+    @Override
+    public Iterator<List<Writable>> call(
+                    Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
+                    throws Exception {
+
+        Iterable<List<Writable>> leftList = t2._2()._1();
+        Iterable<List<Writable>> rightList = t2._2()._2();
+
+        List<List<Writable>> ret = new ArrayList<>();
+        Join.JoinType jt = join.getJoinType();
+        switch (jt) {
+            case Inner:
+                //Return records where key columns appear in BOTH
+                //So if no values from left OR right: no return values
+                for (List<Writable> jvl : leftList) {
+                    for (List<Writable> jvr : rightList) {
+                        List<Writable> joined = join.joinExamples(jvl, jvr);
+                        ret.add(joined);
+                    }
+                }
+                break;
+            case LeftOuter:
+                //Return all records from left, even if no corresponding right value (NullWritable in that case)
+                for (List<Writable> jvl : leftList) {
+                    if (Iterables.size(rightList) == 0) {
+                        List<Writable> joined = join.joinExamples(jvl, null);
+                        ret.add(joined);
+                    } else {
+                        for (List<Writable> jvr : rightList) {
+                            List<Writable> joined = join.joinExamples(jvl, jvr);
+                            ret.add(joined);
+                        }
+                    }
+                }
+                break;
+            case RightOuter:
+                //Return all records from right, even if no corresponding left value (NullWritable in that case)
+                for (List<Writable> jvr : rightList) {
+                    if (Iterables.size(leftList) == 0) {
+                        List<Writable> joined = join.joinExamples(null, jvr);
+                        ret.add(joined);
+                    } else {
+                        for (List<Writable> jvl : leftList) {
+                            List<Writable> joined = join.joinExamples(jvl, jvr);
+                            ret.add(joined);
+                        }
+                    }
+                }
+                break;
+            case FullOuter:
+                //Return all records, even if no corresponding left/right value (NullWritable in that case)
+                if (Iterables.size(leftList) == 0) {
+                    //Only right values
+                    for (List<Writable> jvr : rightList) {
+                        List<Writable> joined = join.joinExamples(null, jvr);
+                        ret.add(joined);
+                    }
+                } else if (Iterables.size(rightList) == 0) {
+                    //Only left values
+                    for (List<Writable> jvl : leftList) {
+                        List<Writable> joined = join.joinExamples(jvl, null);
+                        ret.add(joined);
+                    }
+                } else {
+                    //Records from both left and right
+                    for (List<Writable> jvl : leftList) {
+                        for (List<Writable> jvr : rightList) {
+                            List<Writable> joined = join.joinExamples(jvl, jvr);
+                            ret.add(joined);
+                        }
+                    }
+                }
+                break;
+        }
+
+        return ret.iterator();
     }
 }
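A hedged sketch of how this flat-map function is typically wired up after a cogroup of two keyed RDDs. The key extraction and the construction of the Join object are assumed to happen elsewhere (for example in the Spark transform executor) and are represented here by method parameters; the class and method names are illustrative.

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.join.ExecuteJoinFromCoGroupFlatMapFunction;
import scala.Tuple2;

import java.util.List;

public class CoGroupJoinSketch {
    // Assumes 'left' and 'right' are already keyed by the join-key writables; the cogroup
    // output then matches the Tuple2<key, Tuple2<Iterable, Iterable>> input expected by
    // ExecuteJoinFromCoGroupFlatMapFunction above.
    public static JavaRDD<List<Writable>> join(Join join,
                                               JavaPairRDD<List<Writable>, List<Writable>> left,
                                               JavaPairRDD<List<Writable>, List<Writable>> right) {
        JavaPairRDD<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> coGrouped =
                left.cogroup(right);
        return coGrouped.flatMap(new ExecuteJoinFromCoGroupFlatMapFunction(join));
    }
}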
|
@ -1,119 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.spark.transform.join;
|
|
||||||
|
|
||||||
import com.google.common.collect.Iterables;
|
|
||||||
import org.datavec.api.transform.join.Join;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.datavec.spark.functions.FlatMapFunctionAdapter;
|
|
||||||
import scala.Tuple2;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Execute a join
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements
|
|
||||||
FlatMapFunctionAdapter<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
|
|
||||||
|
|
||||||
private final Join join;
|
|
||||||
|
|
||||||
public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) {
|
|
||||||
this.join = join;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Iterable<List<Writable>> call(
|
|
||||||
Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
|
|
||||||
throws Exception {
|
|
||||||
|
|
||||||
Iterable<List<Writable>> leftList = t2._2()._1();
|
|
||||||
Iterable<List<Writable>> rightList = t2._2()._2();
|
|
||||||
|
|
||||||
List<List<Writable>> ret = new ArrayList<>();
|
|
||||||
Join.JoinType jt = join.getJoinType();
|
|
||||||
switch (jt) {
|
|
||||||
case Inner:
|
|
||||||
//Return records where key columns appear in BOTH
|
|
||||||
//So if no values from left OR right: no return values
|
|
||||||
for (List<Writable> jvl : leftList) {
|
|
||||||
for (List<Writable> jvr : rightList) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case LeftOuter:
|
|
||||||
//Return all records from left, even if no corresponding right value (NullWritable in that case)
|
|
||||||
for (List<Writable> jvl : leftList) {
|
|
||||||
if (Iterables.size(rightList) == 0) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, null);
|
|
||||||
ret.add(joined);
|
|
||||||
} else {
|
|
||||||
for (List<Writable> jvr : rightList) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case RightOuter:
|
|
||||||
//Return all records from right, even if no corresponding left value (NullWritable in that case)
|
|
||||||
for (List<Writable> jvr : rightList) {
|
|
||||||
if (Iterables.size(leftList) == 0) {
|
|
||||||
List<Writable> joined = join.joinExamples(null, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
} else {
|
|
||||||
for (List<Writable> jvl : leftList) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case FullOuter:
|
|
||||||
//Return all records, even if no corresponding left/right value (NullWritable in that case)
|
|
||||||
if (Iterables.size(leftList) == 0) {
|
|
||||||
//Only right values
|
|
||||||
for (List<Writable> jvr : rightList) {
|
|
||||||
List<Writable> joined = join.joinExamples(null, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
} else if (Iterables.size(rightList) == 0) {
|
|
||||||
//Only left values
|
|
||||||
for (List<Writable> jvl : leftList) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, null);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
//Records from both left and right
|
|
||||||
for (List<Writable> jvl : leftList) {
|
|
||||||
for (List<Writable> jvr : rightList) {
|
|
||||||
List<Writable> joined = join.joinExamples(jvl, jvr);
|
|
||||||
ret.add(joined);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
|
@@ -16,10 +16,12 @@
 
 package org.datavec.spark.transform.join;
 
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.datavec.api.transform.join.Join;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
 
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 
 /**
@@ -29,10 +31,43 @@ import java.util.List;
  *
  * @author Alex Black
  */
-public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee<JoinedValue, List<Writable>> {
+public class FilterAndFlattenJoinedValues implements FlatMapFunction<JoinedValue, List<Writable>> {
 
+    private final Join.JoinType joinType;
+
     public FilterAndFlattenJoinedValues(Join.JoinType joinType) {
-        super(new FilterAndFlattenJoinedValuesAdapter(joinType));
+        this.joinType = joinType;
+    }
+
+    @Override
+    public Iterator<List<Writable>> call(JoinedValue joinedValue) throws Exception {
+        boolean keep;
+        switch (joinType) {
+            case Inner:
+                //Only keep joined values where we have both left and right
+                keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
+                break;
+            case LeftOuter:
+                //Keep all values where left is not missing/null
+                keep = joinedValue.isHaveLeft();
+                break;
+            case RightOuter:
+                //Keep all values where right is not missing/null
+                keep = joinedValue.isHaveRight();
+                break;
+            case FullOuter:
+                //Keep all values
+                keep = true;
+                break;
+            default:
+                throw new RuntimeException("Unknown/not implemented join type: " + joinType);
+        }
+
+        if (keep) {
+            return Collections.singletonList(joinedValue.getValues()).iterator();
+        } else {
+            return Collections.emptyIterator();
+        }
     }
 
 }
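Note: the sketch below is illustrative only and not part of the patch. It shows the Spark 2 contract that the refactored classes above now implement directly: a FlatMapFunction whose call() returns an Iterator rather than the Iterable used by the old Spark 1 adapter layer. The class and variable names here are made up for the example.

import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;

public class SplitWords implements FlatMapFunction<String, String> {
    @Override
    public Iterator<String> call(String line) throws Exception {
        // Spark 2.x expects an Iterator here; Spark 1.x expected an Iterable
        return Arrays.asList(line.split(" ")).iterator();
    }
}
// Usage against a hypothetical JavaRDD<String> "lines": lines.flatMap(new SplitWords())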
@@ -1,71 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.join;
-
-import org.datavec.api.transform.join.Join;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Doing two things here:
- * (a) filter out any unnecessary values, and
- * (b) extract the List<Writable> values from the JoinedValue
- *
- * @author Alex Black
- */
-public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter<JoinedValue, List<Writable>> {
-
-    private final Join.JoinType joinType;
-
-    public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) {
-        this.joinType = joinType;
-    }
-
-    @Override
-    public Iterable<List<Writable>> call(JoinedValue joinedValue) throws Exception {
-        boolean keep;
-        switch (joinType) {
-            case Inner:
-                //Only keep joined values where we have both left and right
-                keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
-                break;
-            case LeftOuter:
-                //Keep all values where left is not missing/null
-                keep = joinedValue.isHaveLeft();
-                break;
-            case RightOuter:
-                //Keep all values where right is not missing/null
-                keep = joinedValue.isHaveRight();
-                break;
-            case FullOuter:
-                //Keep all values
-                keep = true;
-                break;
-            default:
-                throw new RuntimeException("Unknown/not implemented join type: " + joinType);
-        }
-
-        if (keep) {
-            return Collections.singletonList(joinedValue.getValues());
-        } else {
-            return Collections.emptyList();
-        }
-    }
-}
@@ -16,21 +16,69 @@
 
 package org.datavec.spark.transform.sparkfunction;
 
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
+import org.apache.spark.sql.types.StructType;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
+import org.datavec.spark.transform.DataFrames;
 
-import java.util.List;
+import java.util.*;
 
 /**
  * Convert a record to a row
  * @author Adam Gibson
 */
-public class SequenceToRows extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, Row> {
+public class SequenceToRows implements FlatMapFunction<List<List<Writable>>, Row> {
+
+    private Schema schema;
+    private StructType structType;
 
     public SequenceToRows(Schema schema) {
-        super(new SequenceToRowsAdapter(schema));
+        this.schema = schema;
+        structType = DataFrames.fromSchemaSequence(schema);
     }
 
+
+    @Override
+    public Iterator<Row> call(List<List<Writable>> sequence) throws Exception {
+        if (sequence.size() == 0)
+            return Collections.emptyIterator();
+
+        String sequenceUUID = UUID.randomUUID().toString();
+
+        List<Row> out = new ArrayList<>(sequence.size());
+
+        int stepCount = 0;
+        for (List<Writable> step : sequence) {
+            Object[] values = new Object[step.size() + 2];
+            values[0] = sequenceUUID;
+            values[1] = stepCount++;
+            for (int i = 0; i < step.size(); i++) {
+                switch (schema.getColumnTypes().get(i)) {
+                    case Double:
+                        values[i + 2] = step.get(i).toDouble();
+                        break;
+                    case Integer:
+                        values[i + 2] = step.get(i).toInt();
+                        break;
+                    case Long:
+                        values[i + 2] = step.get(i).toLong();
+                        break;
+                    case Float:
+                        values[i + 2] = step.get(i).toFloat();
+                        break;
+                    default:
+                        throw new IllegalStateException(
+                                        "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
+                }
+            }
+
+            Row row = new GenericRowWithSchema(values, structType);
+            out.add(row);
+        }
+
+        return out.iterator();
+    }
 }
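Note: the snippet below is a hypothetical usage sketch, not part of the patch; the schema and sequence RDD are assumed to exist. Each sequence step becomes one Row laid out as [sequenceUUID, stepIndex, column values...], as built by the class above.

import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.sparkfunction.SequenceToRows;

public class SequenceToRowsSketch {
    public static JavaRDD<Row> toRows(JavaRDD<List<List<Writable>>> sequences, Schema schema) {
        // One Row per time step, grouped by the per-sequence UUID in column 0
        return sequences.flatMap(new SequenceToRows(schema));
    }
}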
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.sparkfunction;
-
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
-import org.apache.spark.sql.types.StructType;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-import org.datavec.spark.transform.DataFrames;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.UUID;
-
-/**
- * Convert a record to a row
- * @author Adam Gibson
- */
-public class SequenceToRowsAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, Row> {
-
-    private Schema schema;
-    private StructType structType;
-
-    public SequenceToRowsAdapter(Schema schema) {
-        this.schema = schema;
-        structType = DataFrames.fromSchemaSequence(schema);
-    }
-
-
-    @Override
-    public Iterable<Row> call(List<List<Writable>> sequence) throws Exception {
-        if (sequence.size() == 0)
-            return Collections.emptyList();
-
-        String sequenceUUID = UUID.randomUUID().toString();
-
-        List<Row> out = new ArrayList<>(sequence.size());
-
-        int stepCount = 0;
-        for (List<Writable> step : sequence) {
-            Object[] values = new Object[step.size() + 2];
-            values[0] = sequenceUUID;
-            values[1] = stepCount++;
-            for (int i = 0; i < step.size(); i++) {
-                switch (schema.getColumnTypes().get(i)) {
-                    case Double:
-                        values[i + 2] = step.get(i).toDouble();
-                        break;
-                    case Integer:
-                        values[i + 2] = step.get(i).toInt();
-                        break;
-                    case Long:
-                        values[i + 2] = step.get(i).toLong();
-                        break;
-                    case Float:
-                        values[i + 2] = step.get(i).toFloat();
-                        break;
-                    default:
-                        throw new IllegalStateException(
-                                        "This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
-                }
-            }
-
-            Row row = new GenericRowWithSchema(values, structType);
-            out.add(row);
-        }
-
-        return out;
-    }
-}
@@ -16,19 +16,27 @@
 
 package org.datavec.spark.transform.transform;
 
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.datavec.api.transform.sequence.SequenceSplit;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
 
+import java.util.Iterator;
 import java.util.List;
 
 /**
  * Created by Alex on 17/03/2016.
 */
-public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<List<Writable>>> {
+public class SequenceSplitFunction implements FlatMapFunction<List<List<Writable>>, List<List<Writable>>> {
+
+    private final SequenceSplit split;
 
     public SequenceSplitFunction(SequenceSplit split) {
-        super(new SequenceSplitFunctionAdapter(split));
+        this.split = split;
+    }
+
+    @Override
+    public Iterator<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
+        return split.split(collections).iterator();
     }
 
 }
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.transform;
-
-import org.datavec.api.transform.sequence.SequenceSplit;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.List;
-
-/**
- * Created by Alex on 17/03/2016.
- */
-public class SequenceSplitFunctionAdapter
-                implements FlatMapFunctionAdapter<List<List<Writable>>, List<List<Writable>>> {
-
-    private final SequenceSplit split;
-
-    public SequenceSplitFunctionAdapter(SequenceSplit split) {
-        this.split = split;
-    }
-
-    @Override
-    public Iterable<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
-        return split.split(collections);
-    }
-}
@@ -16,19 +16,32 @@
 
 package org.datavec.spark.transform.transform;
 
+import org.apache.spark.api.java.function.FlatMapFunction;
 import org.datavec.api.transform.TransformProcess;
 import org.datavec.api.writable.Writable;
-import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
 
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 
 /**
  * Spark function for executing a transform process
 */
-public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee<List<Writable>, List<Writable>> {
+public class SparkTransformProcessFunction implements FlatMapFunction<List<Writable>, List<Writable>> {
+
+    private final TransformProcess transformProcess;
 
     public SparkTransformProcessFunction(TransformProcess transformProcess) {
-        super(new SparkTransformProcessFunctionAdapter(transformProcess));
+        this.transformProcess = transformProcess;
+    }
+
+    @Override
+    public Iterator<List<Writable>> call(List<Writable> v1) throws Exception {
+        List<Writable> newList = transformProcess.execute(v1);
+        if (newList == null)
+            return Collections.emptyIterator(); //Example was filtered out
+        else
+            return Collections.singletonList(newList).iterator();
    }
 
 }
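Note: hypothetical usage sketch, not part of the patch (the RDD and TransformProcess setup are assumed). It shows how the refactored function is applied: records the TransformProcess filters out return an empty iterator and simply vanish from the output RDD.

import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.transform.SparkTransformProcessFunction;

public class TransformSketch {
    public static JavaRDD<List<Writable>> apply(JavaRDD<List<Writable>> records, TransformProcess tp) {
        // flatMap drops filtered-out records because their iterator is empty
        return records.flatMap(new SparkTransformProcessFunction(tp));
    }
}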
@@ -1,45 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform.transform;
-
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Spark function for executing a transform process
- */
-public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter<List<Writable>, List<Writable>> {
-
-    private final TransformProcess transformProcess;
-
-    public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) {
-        this.transformProcess = transformProcess;
-    }
-
-    @Override
-    public Iterable<List<Writable>> call(List<Writable> v1) throws Exception {
-        List<Writable> newList = transformProcess.execute(v1);
-        if (newList == null)
-            return Collections.emptyList(); //Example was filtered out
-        else
-            return Collections.singletonList(newList);
-    }
-}
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-/**
- * FlatMapFunction adapter to
- * hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to FlatMapFunction
- *
- */
-public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {
-
-    protected final FlatMapFunctionAdapter<K, V> adapter;
-
-    public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
-        this.adapter = adapter;
-    }
-
-    @Override
-    public Iterable<V> call(K k) throws Exception {
-        return adapter.call(k);
-    }
-}
@@ -1,42 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.sql.DataFrame;
-
-/**
- * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to DataFrame / Dataset
- *
- */
-public class DataRowsFacade {
-
-    private final DataFrame df;
-
-    private DataRowsFacade(DataFrame df) {
-        this.df = df;
-    }
-
-    public static DataRowsFacade dataRows(DataFrame df) {
-        return new DataRowsFacade(df);
-    }
-
-    public DataFrame get() {
-        return df;
-    }
-}
@@ -1,42 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.datavec.spark.functions.FlatMapFunctionAdapter;
-
-import java.util.Iterator;
-
-/**
- * FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to FlatMapFunction
- *
- */
-public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {
-
-    protected final FlatMapFunctionAdapter<K, V> adapter;
-
-    public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
-        this.adapter = adapter;
-    }
-
-    @Override
-    public Iterator<V> call(K k) throws Exception {
-        return adapter.call(k).iterator();
-    }
-}
@@ -1,43 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.datavec.spark.transform;
-
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-
-/**
- * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
- *
- * This class should be used instead of direct referral to DataFrame / Dataset
- *
- */
-public class DataRowsFacade {
-
-    private final Dataset<Row> df;
-
-    private DataRowsFacade(Dataset<Row> df) {
-        this.df = df;
-    }
-
-    public static DataRowsFacade dataRows(Dataset<Row> df) {
-        return new DataRowsFacade(df);
-    }
-
-    public Dataset<Row> get() {
-        return df;
-    }
-}
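Note: sketch of the replacement pattern, mirroring the test updates further below; schema and rdd are assumed to exist and are not part of the patch. With only Spark 2 supported, the DataRowsFacade wrapper and its .get() unwrapping are gone and callers work with Dataset<Row> directly.

import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;

public class FacadeRemovalSketch {
    public static void show(Schema schema, JavaRDD<List<Writable>> rdd) {
        // DataFrames.toDataFrame now returns the Dataset itself - no facade in between
        Dataset<Row> df = DataFrames.toDataFrame(schema, rdd);
        df.show();
        df.describe(DataFrames.toArray(schema.getColumnNames())).show();
    }
}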
@@ -16,7 +16,7 @@
 
 package org.datavec.spark.storage;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.datavec.api.writable.*;
@@ -19,6 +19,8 @@ package org.datavec.spark.transform;
 import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.util.ndarray.RecordConverter;
 import org.datavec.api.writable.DoubleWritable;
@@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest {
         for (int i = 0; i < numColumns; i++)
             builder.addColumnDouble(String.valueOf(i));
         Schema schema = builder.build();
-        DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
-        dataFrame.get().show();
-        dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show();
+        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
+        dataFrame.show();
+        dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show();
         //        System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames()));
         //        System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames()));
@@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest {
         assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
         assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
 
-        DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
-        dataFrame.get().show();
+        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
+        dataFrame.show();
         Column mean = DataFrames.mean(dataFrame, "0");
         Column std = DataFrames.std(dataFrame, "0");
-        dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show();
-        dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show();
+        dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show();
+        dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show();
 
         /*        DataFrame desc = dataFrame.describe(dataFrame.columns());
         dataFrame.show();
@@ -17,6 +17,7 @@
 package org.datavec.spark.transform;
 
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.util.ndarray.RecordConverter;
@@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable;
 import org.datavec.api.writable.Writable;
 import org.datavec.spark.BaseSparkTest;
 import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
 import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
 import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest {
         for (int i = 0; i < numColumns; i++)
             builder.addColumnDouble(String.valueOf(i));
 
+        Nd4j.getRandom().setSeed(12345);
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns);
         for (int i = 0; i < 5; i++) {
             List<Writable> record = new ArrayList<>(numColumns);
             data.add(record);
             for (int j = 0; j < numColumns; j++) {
-                record.add(new DoubleWritable(1.0));
+                record.add(new DoubleWritable(arr.getDouble(i, j)));
             }
 
         }
 
-        INDArray arr = RecordConverter.toMatrix(data);
-
         Schema schema = builder.build();
         JavaRDD<List<Writable>> rdd = sc.parallelize(data);
-        DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
+        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
 
         //assert equivalent to the ndarray pre processing
-        NormalizerStandardize standardScaler = new NormalizerStandardize();
-        standardScaler.fit(new DataSet(arr.dup(), arr.dup()));
-        INDArray standardScalered = arr.dup();
-        standardScaler.transform(new DataSet(standardScalered, standardScalered));
         DataNormalization zeroToOne = new NormalizerMinMaxScaler();
         zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
         INDArray zeroToOnes = arr.dup();
         zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));
-        List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns());
+        List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
         INDArray assertion = DataFrames.toMatrix(rows);
-        //compare standard deviation
-        assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1));
+        INDArray expStd = arr.std(true, true, 0);
+        INDArray std = assertion.getRow(0, true);
+        assertTrue(expStd.equalsWithEps(std, 1e-3));
         //compare mean
-        assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1));
+        INDArray expMean = arr.mean(true, 0);
+        assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3));
 
     }
@@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest {
         assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
         assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
 
-        DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
-        dataFrame.get().show();
-        Normalization.zeromeanUnitVariance(dataFrame).get().show();
-        Normalization.normalize(dataFrame).get().show();
+        Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
+        dataFrame.show();
+        Normalization.zeromeanUnitVariance(dataFrame).show();
+        Normalization.normalize(dataFrame).show();
 
         //assert equivalent to the ndarray pre processing
         NormalizerStandardize standardScaler = new NormalizerStandardize();
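Note: the fragment below is a sketch based on the test code above (dataFrame is assumed to be the Dataset<Row> built there); it is not additional patch content. It shows what the test relies on: in the matrix built from Normalization.stdDevMeanColumns, row 0 holds the per-column standard deviation and row 1 the per-column mean, which are then compared against INDArray statistics.

// Fragment (uses the test's dataFrame variable)
List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
INDArray stats = DataFrames.toMatrix(rows);
INDArray std = stats.getRow(0, true);   // per-column standard deviation
INDArray mean = stats.getRow(1, true);  // per-column mean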
@@ -52,18 +52,6 @@ public class DL4JSystemProperties {
      */
     public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl";
 
-    /**
-     * Applicability: deeplearning4j-nn<br>
-     * Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an
-     * alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in
-     * comma-separated format.<br>
-     * This is required ONLY when ALL of the following conditions are met:<br>
-     * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND<br>
-     * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J), AND<br>
-     * 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}
-     */
-    public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses";
-
     /**
      * Applicability: deeplearning4j-nn<br>
      * Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled
@@ -96,18 +96,6 @@
             <scope>test</scope>
         </dependency>
 
-        <dependency>
-            <groupId>com.google.guava</groupId>
-            <artifactId>guava</artifactId>
-            <version>${guava.version}</version>
-            <exclusions>
-                <exclusion>
-                    <groupId>com.google.code.findbugs</groupId>
-                    <artifactId>jsr305</artifactId>
-                </exclusion>
-            </exclusions>
-        </dependency>
-
         <dependency>
             <groupId>org.nd4j</groupId>
             <artifactId>nd4j-api</artifactId>
@@ -16,7 +16,7 @@
 
 package org.deeplearning4j.datasets.datavec;
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.FilenameUtils;
@@ -17,7 +17,7 @@
 package org.deeplearning4j.datasets.datavec;
 
 
-import com.google.common.io.Files;
+import org.nd4j.shade.guava.io.Files;
 import org.apache.commons.compress.utils.IOUtils;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.FilenameUtils;
@@ -16,8 +16,8 @@
 
 package org.deeplearning4j.nn.dtypes;
 
-import com.google.common.collect.ImmutableSet;
-import com.google.common.reflect.ClassPath;
+import org.nd4j.shade.guava.collect.ImmutableSet;
+import org.nd4j.shade.guava.reflect.ClassPath;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
@@ -103,7 +103,7 @@ public class DTypeTests extends BaseDL4JTest {
         ImmutableSet<ClassPath.ClassInfo> info;
         try {
             //Dependency note: this ClassPath class was added in Guava 14
-            info = com.google.common.reflect.ClassPath.from(DTypeTests.class.getClassLoader())
+            info = org.nd4j.shade.guava.reflect.ClassPath.from(DTypeTests.class.getClassLoader())
                             .getTopLevelClassesRecursive("org.deeplearning4j");
         } catch (IOException e) {
             //Should never happen
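Note: illustrative sketch only, not part of the patch. Guava types are now consumed through ND4J's shaded namespace instead of com.google.common, so user code no longer collides with whatever Guava version Spark or Hadoop puts on the classpath.

import org.nd4j.shade.guava.collect.ImmutableSet;

public class ShadedGuavaSketch {
    public static ImmutableSet<String> names() {
        // Same Guava API, relocated package - only the import changes
        return ImmutableSet.of("alpha", "beta");
    }
}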
@@ -229,7 +229,9 @@ public class TestRnnLayers extends BaseDL4JTest {
                 net.fit(in,l);
             } catch (Throwable t){
                 String msg = t.getMessage();
-                assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
+                if(msg == null)
+                    t.printStackTrace();
+                assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
             }
 
         }
|
||||||
|
|
||||||
package org.deeplearning4j.plot;
|
package org.deeplearning4j.plot;
|
||||||
|
|
||||||
import com.google.common.util.concurrent.AtomicDouble;
|
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
|
|
|
@ -22,23 +22,24 @@ import org.deeplearning4j.TestUtils;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.conf.BackpropType;
|
import org.deeplearning4j.nn.conf.BackpropType;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||||
import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.impl.*;
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.RmsProp;
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
@ -60,6 +61,9 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomLayer() throws Exception {
|
public void testCustomLayer() throws Exception {
|
||||||
|
//We dropped support for 1.0.0-alpha and earlier custom layers due to the maintenance overhead for a rarely used feature
|
||||||
|
//An upgrade path exists as a workaround - load in beta to beta4 and re-save
|
||||||
|
//All built-in layers can be loaded going back to 0.5.0
|
||||||
|
|
||||||
File f = Resources.asFile("regression_testing/100a/CustomLayerExample_100a.bin");
|
File f = Resources.asFile("regression_testing/100a/CustomLayerExample_100a.bin");
|
||||||
|
|
||||||
|
@ -68,67 +72,8 @@ public class RegressionTest100a extends BaseDL4JTest {
|
||||||
fail("Expected exception");
|
fail("Expected exception");
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
String msg = e.getMessage();
|
String msg = e.getMessage();
|
||||||
assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
|
assertTrue(msg, msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again"));
|
||||||
}
|
}
|
||||||
|
|
||||||
NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
|
|
||||||
|
|
||||||
MultiLayerNetwork net = MultiLayerNetwork.load(f, true);
|
|
||||||
|
|
||||||
DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer();
|
|
||||||
assertEquals(new ActivationTanH(), l0.getActivationFn());
|
|
||||||
assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0));
|
|
||||||
assertEquals(new RmsProp(0.95), l0.getIUpdater());
|
|
||||||
|
|
||||||
CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer();
|
|
||||||
assertEquals(new ActivationTanH(), l1.getActivationFn());
|
|
||||||
assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction());
|
|
||||||
assertEquals(new RmsProp(0.95), l1.getIUpdater());
|
|
||||||
|
|
||||||
|
|
||||||
INDArray outExp;
|
|
||||||
File f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Output_100a.bin");
|
|
||||||
try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
|
|
||||||
outExp = Nd4j.read(dis);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray in;
|
|
||||||
File f3 = Resources.asFile("regression_testing/100a/CustomLayerExample_Input_100a.bin");
|
|
||||||
try(DataInputStream dis = new DataInputStream(new FileInputStream(f3))){
|
|
||||||
in = Nd4j.read(dis);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray outAct = net.output(in);
|
|
||||||
|
|
||||||
assertEquals(outExp, outAct);
|
|
||||||
|
|
||||||
|
|
||||||
//Check graph
|
|
||||||
f = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_100a.bin");
|
|
||||||
|
|
||||||
//Deregister custom class:
|
|
||||||
new LegacyLayerDeserializer().getLegacyNamesMap().remove("CustomLayer");
|
|
||||||
|
|
||||||
try {
|
|
||||||
ComputationGraph.load(f, true);
|
|
||||||
fail("Expected exception");
|
|
||||||
} catch (Exception e){
|
|
||||||
String msg = e.getMessage();
|
|
||||||
assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON"));
|
|
||||||
}
|
|
||||||
|
|
||||||
NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class);
|
|
||||||
|
|
||||||
ComputationGraph graph = ComputationGraph.load(f, true);
|
|
||||||
|
|
||||||
f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_Output_100a.bin");
|
|
||||||
try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){
|
|
||||||
outExp = Nd4j.read(dis);
|
|
||||||
}
|
|
||||||
|
|
||||||
outAct = graph.outputSingle(in);
|
|
||||||
|
|
||||||
assertEquals(outExp, outAct);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
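Note: hedged sketch of the upgrade path described in the test comments above. It is meant to be run with an intermediate 1.0.0-beta through beta4 release (not with the version in this PR), and the file names are placeholders.

import java.io.File;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

public class LegacyCustomLayerUpgrade {
    public static void main(String[] args) throws Exception {
        // Load a pre-1.0.0-beta model containing custom layers with an older release...
        MultiLayerNetwork net = MultiLayerNetwork.load(new File("net_saved_with_1.0.0-alpha.zip"), true);
        // ...and re-save it so the newer releases can read the modern JSON format
        ModelSerializer.writeModel(net, new File("net_resaved_with_beta4.zip"), true);
    }
}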
@@ -16,8 +16,8 @@
 
 package org.deeplearning4j.datasets.iterator;
 
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.annotations.VisibleForTesting;
+import org.nd4j.shade.guava.collect.Lists;
 import lombok.Getter;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
@@ -16,7 +16,7 @@
 
 package org.deeplearning4j.datasets.iterator.parallel;
 
-import com.google.common.collect.Lists;
+import org.nd4j.shade.guava.collect.Lists;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.FileUtils;
|
||||||
package org.deeplearning4j.plot;
|
package org.deeplearning4j.plot;
|
||||||
|
|
||||||
|
|
||||||
import com.google.common.util.concurrent.AtomicDouble;
|
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.plot;
|
package org.deeplearning4j.plot;
|
||||||
|
|
||||||
import com.google.common.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dimensionalityreduction.PCA;
|
import org.nd4j.linalg.dimensionalityreduction.PCA;
|
||||||
|
|
|
@@ -37,6 +37,12 @@
             <version>${nd4j.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>com.google.code.gson</groupId>
+            <artifactId>gson</artifactId>
+            <version>${gson.version}</version>
+        </dependency>
+
         <dependency>
             <groupId>org.deeplearning4j</groupId>
             <artifactId>deeplearning4j-nn</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.google.guava</groupId>
|
|
||||||
<artifactId>guava</artifactId>
|
|
||||||
<version>${guava.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.google.protobuf</groupId>
|
<groupId>com.google.protobuf</groupId>
|
||||||
<artifactId>protobuf-java</artifactId>
|
<artifactId>protobuf-java</artifactId>
|
||||||
<version>${google.protobuf.version}</version>
|
<version>${google.protobuf.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe.akka</groupId>
|
|
||||||
<artifactId>akka-actor_2.11</artifactId>
|
|
||||||
<version>${akka.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe.akka</groupId>
|
|
||||||
<artifactId>akka-slf4j_2.11</artifactId>
|
|
||||||
<version>${akka.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>joda-time</groupId>
|
<groupId>joda-time</groupId>
|
||||||
<artifactId>joda-time</artifactId>
|
<artifactId>joda-time</artifactId>
|
||||||
|
@ -213,11 +198,6 @@
|
||||||
<artifactId>play-netty-server_2.11</artifactId>
|
<artifactId>play-netty-server_2.11</artifactId>
|
||||||
<version>${playframework.version}</version>
|
<version>${playframework.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>com.typesafe.akka</groupId>
|
|
||||||
<artifactId>akka-cluster_2.11</artifactId>
|
|
||||||
<version>${akka.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@@ -1,41 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.nearestneighbor.server;
-
-import play.libs.F;
-import play.mvc.Result;
-
-import java.util.function.Function;
-import java.util.function.Supplier;
-
-/**
- * Utility methods for Routing
- *
- * @author Alex Black
- */
-public class FunctionUtil {
-
-
-    public static F.Function0<Result> function0(Supplier<Result> supplier) {
-        return supplier::get;
-    }
-
-    public static <T> F.Function<T, Result> function(Function<T, Result> function) {
-        return function::apply;
-    }
-
-}
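
The FunctionUtil class deleted above existed only to adapt java.util.function lambdas to the play.libs.F functional interfaces that the old Java routing DSL required. With the newer Play Java API used after this change, RoutingDsl handlers are plain lambdas that can receive the Http.Request directly, so the adapter is obsolete. Below is a minimal, self-contained sketch of the new pattern; the route path, response body and class name are illustrative and not part of this PR.

import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import static play.mvc.Results.ok;

public class LambdaRoutingSketch {

    // Build a Router from BuiltInComponents, as the new DSL expects.
    static Router createRouter(BuiltInComponents components) {
        return RoutingDsl.fromComponents(components)
                .POST("/echo")
                .routingTo(request -> ok(request.body().asJson())) // plain lambda, no FunctionUtil adapter
                .build();
    }

    public static void main(String[] args) {
        // Start an embedded server on port 9000 using the router factory above.
        Server server = Server.forRouter(Mode.DEV, 9000, LambdaRoutingSketch::createRouter);
        // ... call server.stop() when finished.
    }
}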
@@ -33,8 +33,10 @@ import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.serde.base64.Nd4jBase64;
 import org.nd4j.serde.binary.BinarySerde;
+import play.BuiltInComponents;
 import play.Mode;
 import play.libs.Json;
+import play.routing.Router;
 import play.routing.RoutingDsl;
 import play.server.Server;
 
@@ -149,19 +151,36 @@ public class NearestNeighborsServer {
 
         VPTree tree = new VPTree(points, similarityFunction, invert);
 
-        RoutingDsl routingDsl = new RoutingDsl();
+        //Set play secret key, if required
+        //http://www.playframework.com/documentation/latest/ApplicationSecret
+        String crypto = System.getProperty("play.crypto.secret");
+        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
+            byte[] newCrypto = new byte[1024];
+
+            new Random().nextBytes(newCrypto);
+
+            String base64 = Base64.getEncoder().encodeToString(newCrypto);
+            System.setProperty("play.crypto.secret", base64);
+        }
+
+
+        server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b));
+    }
+
+    protected Router createRouter(VPTree tree, List<String> labels, INDArray points, BuiltInComponents builtInComponents){
+        RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
         //return the host information for a given id
-        routingDsl.POST("/knn").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/knn").routingTo(request -> {
             try {
-                NearestNeighborRequest record = Json.fromJson(request().body().asJson(), NearestNeighborRequest.class);
+                NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class);
                 NearestNeighbor nearestNeighbor =
                                 NearestNeighbor.builder().points(points).record(record).tree(tree).build();
 
                 if (record == null)
                     return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
 
                 NearestNeighborsResults results =
                                 NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
 
 
                 return ok(Json.toJson(results));
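
The application-secret block added in the hunk above is the same block removed further down (hunk @@ -216,23 +235,9 @@): it is moved earlier so the secret is in place before Server.forRouter starts the embedded server, while the route declarations move into the new createRouter method. A small standalone sketch of that bootstrap follows; it mirrors the moved block, except that SecureRandom is substituted for java.util.Random as a hardening suggestion that is not part of this PR.

import java.security.SecureRandom;
import java.util.Base64;

public class PlaySecretBootstrapSketch {
    public static void main(String[] args) {
        // If no usable application secret is configured, generate a random one
        // so an embedded Play server can start (see the ApplicationSecret docs
        // linked in the diff above).
        String crypto = System.getProperty("play.crypto.secret");
        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
            byte[] newSecret = new byte[1024];
            new SecureRandom().nextBytes(newSecret); // the PR itself uses new Random(); SecureRandom is an assumption here
            System.setProperty("play.crypto.secret",
                    Base64.getEncoder().encodeToString(newSecret));
        }
        System.out.println("play.crypto.secret configured: "
                + (System.getProperty("play.crypto.secret") != null));
    }
}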
@@ -171,11 +190,11 @@ public class NearestNeighborsServer {
                 e.printStackTrace();
                 return internalServerError(e.getMessage());
             }
-        })));
+        });
 
-        routingDsl.POST("/knnnew").routeTo(FunctionUtil.function0((() -> {
+        routingDsl.POST("/knnnew").routingTo(request -> {
             try {
-                Base64NDArrayBody record = Json.fromJson(request().body().asJson(), Base64NDArrayBody.class);
+                Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class);
                 if (record == null)
                     return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed.")));
 
@@ -216,23 +235,9 @@ public class NearestNeighborsServer {
                 e.printStackTrace();
                 return internalServerError(e.getMessage());
             }
-        })));
+        });
 
-        //Set play secret key, if required
-        //http://www.playframework.com/documentation/latest/ApplicationSecret
-        String crypto = System.getProperty("play.crypto.secret");
-        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
-            byte[] newCrypto = new byte[1024];
-
-            new Random().nextBytes(newCrypto);
-
-            String base64 = Base64.getEncoder().encodeToString(newCrypto);
-            System.setProperty("play.crypto.secret", base64);
-        }
-
-        server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
-
-
+        return routingDsl.build();
     }
 
     /**