commit b393d3fdb1

@@ -72,6 +72,11 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>${jodatime.version}</version>
</dependency>

<!-- ND4J Shaded Jackson Dependency -->
<dependency>

@@ -80,4 +85,13 @@
<version>${nd4j.version}</version>
</dependency>
</dependencies>

<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project>

@@ -16,7 +16,7 @@

package org.deeplearning4j.arbiter.optimize.distribution;

import com.google.common.base.Preconditions;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException;

@@ -16,7 +16,7 @@

package org.deeplearning4j.arbiter.optimize.runner;

import com.google.common.util.concurrent.ListenableFuture;
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

@@ -16,9 +16,9 @@

package org.deeplearning4j.arbiter.optimize.runner;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;

@@ -43,13 +43,15 @@ public class JsonMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory());
mapper.registerModule(new JodaModule());
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
yamlMapper.registerModule(new JodaModule());
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}

private JsonMapper() {}
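The extra PropertyAccessor.CREATOR visibility call added above recurs for every mapper touched by this commit; the commit's own comment in JsonMappers.configureMapper (later in this diff) says it is needed so that @JsonProperty annotations on constructors are still seen once ALL visibility has been set to NONE. A minimal illustrative class, not part of the commit (name and field are made up), showing the kind of object that relies on that setting:

import org.nd4j.shade.jackson.annotation.JsonProperty;

// Illustrative only: ALL=NONE turns off member auto-detection, FIELD=ANY restores access to
// the private field, and (per the configureMapper comment in this commit) CREATOR=ANY keeps
// the @JsonProperty-annotated constructor visible for deserialization.
public class IllustrativeConfig {
    private final double learningRate;

    public IllustrativeConfig(@JsonProperty("learningRate") double learningRate) {
        this.learningRate = learningRate;
    }

    public double getLearningRate() {
        return learningRate;
    }
}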
@@ -39,6 +39,7 @@ public class YamlMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
}

@@ -59,6 +59,7 @@ public class TestJson {
om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
return om;
}

@@ -38,13 +38,6 @@
<version>${dl4j.version}</version>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>

@@ -64,6 +57,20 @@
<artifactId>jackson</artifactId>
<version>${nd4j.version}</version>
</dependency>

<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
</dependencies>

<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project>

@@ -16,7 +16,7 @@

package org.deeplearning4j.arbiter.layers;

import com.google.common.base.Preconditions;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.AccessLevel;
import lombok.Data;
import lombok.EqualsAndHashCode;

@@ -49,11 +49,14 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project>

@@ -97,15 +97,16 @@
</plugins>
</build>
</profile>

<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>

<dependencies>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>

<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-core</artifactId>

@@ -124,13 +125,6 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>

@@ -139,9 +133,6 @@
</dependency>
</dependencies>

<build>
<extensions>
<extension>

@@ -222,5 +213,4 @@
</plugins>
</pluginManagement>
</build>

</project>

@@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.ui.misc;

import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;

@@ -45,12 +46,9 @@ public class JsonMapper {
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
mapper.enable(SerializationFeature.INDENT_OUTPUT);

mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker()
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
.withCreatorVisibility(JsonAutoDetect.Visibility.NONE));
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);

return mapper;
}

@@ -136,6 +136,31 @@

<pluginManagement>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>${maven-enforcer-plugin.version}</version>
<executions>
<execution>
<phase>test</phase>
<id>enforce-test-resources</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<skip>${skipTestResourceEnforcement}</skip>
<rules>
<requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles>
<all>false</all>
</requireActiveProfile>
</rules>
<fail>true</fail>
</configuration>
</execution>
</executions>
</plugin>

<plugin>
<artifactId>maven-javadoc-plugin</artifactId>
<version>${maven-javadoc-plugin.version}</version>
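Note on the enforcer execution added above: with requireActiveProfile and fail set to true, the build fails at the test phase unless one of the two backend profiles, test-nd4j-native or test-nd4j-cuda-10.1, is active, which is why the same pair of profile ids is added to each module's pom in this commit and why the skipTestResourceEnforcement property is exposed as an opt-out. Selecting a backend profile on the command line (for example via Maven's -P flag) is illustrative usage, not something introduced by the commit itself.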
@@ -287,4 +312,42 @@
</plugin>
</plugins>
</reporting>

<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>

<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

@@ -20,9 +20,9 @@

set -e

VALID_VERSIONS=( 2.10 2.11 )
SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}";
VALID_VERSIONS=( 2.11 2.12 )
SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}";
SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}";

usage() {
echo "Usage: $(basename $0) [-h|--help] <scala version to be used>

@@ -45,19 +45,18 @@ check_scala_version() {
exit 1
}

check_scala_version "$TO_VERSION"

if [ $TO_VERSION = "2.11" ]; then
FROM_BINARY="_2\.10"
FROM_BINARY="_2\.12"
TO_BINARY="_2\.11"
FROM_VERSION=$SCALA_210_VERSION
FROM_VERSION=$SCALA_212_VERSION
TO_VERSION=$SCALA_211_VERSION
else
FROM_BINARY="_2\.11"
TO_BINARY="_2\.10"
TO_BINARY="_2\.12"
FROM_VERSION=$SCALA_211_VERSION
TO_VERSION=$SCALA_210_VERSION
TO_VERSION=$SCALA_212_VERSION
fi

sed_i() {

@@ -70,35 +69,24 @@ echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION t

BASEDIR=$(dirname $0)

#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc.
#Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc.
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \;

#Scala versions, like <scala.version>2.10</scala.version>
#Scala versions, like <scala.version>2.11</scala.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \;

#Scala binary versions, like <scala.binary.version>2.10</scala.binary.version>
#Scala binary versions, like <scala.binary.version>2.11</scala.binary.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \;

#Scala versions, like <artifactId>scala-library</artifactId> <version>2.10.6</version>
#Scala versions, like <artifactId>scala-library</artifactId> <version>2.11.12</version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \;

#Scala maven plugin, <scalaVersion>2.10</scalaVersion>
#Scala maven plugin, <scalaVersion>2.11</scalaVersion>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \;

#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306
#https://github.com/deeplearning4j/deeplearning4j/issues/6306
if [[ $TO_VERSION == 2.11* ]]; then
sed_i 's/<artifactId>korean-text-scala-2.10<\/artifactId>/<artifactId>korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
sed_i 's/<version>4.2.0<\/version>/<version>4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
else
sed_i 's/<artifactId>korean-text<\/artifactId>/<artifactId>korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
sed_i 's/<version>4.4<\/version>/<version>4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
fi

echo "Done updating Scala versions.";

@@ -1,85 +0,0 @@
#!/usr/bin/env bash

################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################

# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications.

set -e

VALID_VERSIONS=( 1 2 )
SPARK_2_VERSION="2\.1\.0"
SPARK_1_VERSION="1\.6\.3"

usage() {
echo "Usage: $(basename $0) [-h|--help] <spark version to be used>
where :
-h| --help Display this help text
valid spark version values : ${VALID_VERSIONS[*]}
" 1>&2
exit 1
}

if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
usage
fi

TO_VERSION=$1

check_spark_version() {
for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
exit 1
}

check_spark_version "$TO_VERSION"

if [ $TO_VERSION = "2" ]; then
FROM_BINARY="1"
TO_BINARY="2"
FROM_VERSION=$SPARK_1_VERSION
TO_VERSION=$SPARK_2_VERSION
else
FROM_BINARY="2"
TO_BINARY="1"
FROM_VERSION=$SPARK_2_VERSION
TO_VERSION=$SPARK_1_VERSION
fi

sed_i() {
sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
}

export -f sed_i

echo "Updating Spark versions in pom.xml files to Spark $1";

BASEDIR=$(dirname $0)

# <spark.major.version>1</spark.major.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \;

# <spark.version>1.6.3</spark.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \;

#Spark versions, like <version>xxx_spark_2xxx</version> OR <datavec.spark.version>xxx_spark_2xxx</datavec.spark.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \;

echo "Done updating Spark versions.";

@@ -26,7 +26,6 @@

<artifactId>datavec-api</artifactId>

<dependencies>
<dependency>
<groupId>org.apache.commons</groupId>

@@ -98,13 +97,6 @@
<version>${stream.analytics.version}</version>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<!-- csv parser, same dep used by spark -->
<dependency>
<groupId>net.sf.opencsv</groupId>

@@ -125,7 +117,6 @@
</dependency>
</dependencies>

<profiles>
<profile>
<id>test-nd4j-native</id>

@@ -16,7 +16,6 @@

package org.datavec.api.transform;

import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -27,8 +26,7 @@ import java.util.List;
/**A Transform converts an example to another example, or a sequence to another sequence
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.TransformHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Transform extends Serializable, ColumnOp {

/**

@@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

import java.io.IOException;
import java.io.Serializable;

@@ -417,6 +418,16 @@ public class TransformProcess implements Serializable {
public static TransformProcess fromJson(String json) {
try {
return JsonMappers.getMapper().readValue(json, TransformProcess.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
try{
return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e) {
//TODO proper exception message
throw new RuntimeException(e);
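The fromJson change above, repeated below for DataAnalysis, SequenceDataAnalysis and Schema, first tries the current "@class"-based JSON and only falls back to the legacy mapper when Jackson reports a missing type id. A small round-trip sketch, not part of the commit, assuming the standard TransformProcess and Schema builder APIs (the column name is made up):

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;

// Illustrative usage of the fromJson path added above.
public class TransformProcessJsonRoundTrip {
    public static void main(String[] args) {
        Schema schema = new Schema.Builder()
                .addColumnString("someColumn")          // hypothetical column name
                .build();
        TransformProcess tp = new TransformProcess.Builder(schema)
                .removeColumns("someColumn")
                .build();
        String json = tp.toJson();                      // current format, written with an "@class" field
        TransformProcess restored = TransformProcess.fromJson(json);
        System.out.println(restored.getFinalSchema());
        // JSON saved with 1.0.0-alpha or earlier would instead hit the InvalidTypeIdException
        // branch above and be read via JsonMappers.getLegacyMapper().
    }
}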
@@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.node.ArrayNode;

import java.io.IOException;

@@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable {
public static DataAnalysis fromJson(String json) {
try{
return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (Exception e){
//Legacy format
ObjectMapper om = new JsonSerializer().getObjectMapper();

@@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

import java.io.IOException;
import java.util.List;

@@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis {
public static SequenceDataAnalysis fromJson(String json){
try{
return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e){
throw new RuntimeException(e);
}

@@ -17,7 +17,6 @@
package org.datavec.api.transform.analysis.columns;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -27,8 +26,7 @@ import java.io.Serializable;
* Interface for column analysis
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ColumnAnalysis extends Serializable {

long getCountTotal();

@@ -18,7 +18,6 @@ package org.datavec.api.transform.condition;

import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -35,8 +34,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.ConditionHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Condition extends Serializable, ColumnOp {

/**

@@ -18,7 +18,6 @@ package org.datavec.api.transform.filter;

import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -33,8 +32,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.FilterHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Filter extends Serializable, ColumnOp {

/**

@@ -17,7 +17,6 @@
package org.datavec.api.transform.metadata;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -32,8 +31,7 @@ import java.io.Serializable;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ColumnMetaData extends Serializable, Cloneable {

/**

@@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static org.nd4j.shade.guava.base.Preconditions.checkArgument;
import static org.nd4j.shade.guava.base.Preconditions.checkNotNull;

/**
* A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition},

@@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.LongMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.comparator.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;

@@ -50,8 +49,7 @@ import java.util.List;
@EqualsAndHashCode(exclude = {"inputSchema"})
@JsonIgnoreProperties({"inputSchema"})
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public class CalculateSortedRank implements Serializable, ColumnOp {

private final String newColumnName;

@@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.*;
import org.joda.time.DateTimeZone;
import org.nd4j.shade.jackson.annotation.*;

@@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

@@ -48,8 +49,7 @@ import java.util.*;
*/
@JsonIgnoreProperties({"columnNames", "columnNamesIndex"})
@EqualsAndHashCode
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.SchemaHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
public class Schema implements Serializable {

@@ -358,6 +358,16 @@ public class Schema implements Serializable {
public static Schema fromJson(String json) {
try{
return JsonMappers.getMapper().readValue(json, Schema.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, Schema.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (Exception e){
//TODO better exceptions
throw new RuntimeException(e);

@@ -379,21 +389,6 @@
}
}

private static Schema fromJacksonString(String str, JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);
om.registerModule(new JodaModule());
om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
try {
return om.readValue(str, Schema.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public static class Builder {
List<ColumnMetaData> columnMetaData = new ArrayList<>();
@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence;

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -30,8 +29,7 @@ import java.util.List;
* Compare the time steps of a sequence
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface SequenceComparator extends Comparator<List<Writable>>, Serializable {

/**

@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence;

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -32,8 +31,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface SequenceSplit extends Serializable {

/**

@@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence.window;

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@@ -36,8 +35,7 @@ import java.util.List;
* @author Alex Black
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface WindowFunction extends Serializable {

/**

@@ -16,44 +16,17 @@

package org.datavec.api.transform.serde;

import lombok.AllArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.WritableComparator;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.condition.column.ColumnCondition;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.json.LegacyIActivationDeserializer;
import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.databind.*;
import org.nd4j.shade.jackson.databind.cfg.MapperConfig;
import org.nd4j.shade.jackson.databind.introspect.Annotated;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass;
import org.nd4j.shade.jackson.databind.introspect.AnnotationMap;
import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector;
import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule;

import java.lang.annotation.Annotation;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
* JSON mappers for deserializing neural net configurations, etc.
*

@@ -62,38 +35,9 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class JsonMappers {

/**
* This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])}
* Classes can be specified in comma-separated format
*/
public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses";

static {
String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY);
if(p != null && !p.isEmpty()){
String[] split = p.split(",");
List<Class<?>> list = new ArrayList<>();
for(String s : split){
try{
Class<?> c = Class.forName(s);
list.add(c);
} catch (Throwable t){
log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t);
}
}

if(list.size() > 0){
try {
registerLegacyCustomClassesForJSONList(list);
} catch (Throwable t){
log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t);
}
}
}
}

private static ObjectMapper jsonMapper;
private static ObjectMapper yamlMapper;
private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc

static {
jsonMapper = new ObjectMapper();

@@ -102,117 +46,12 @@ public class JsonMappers {
configureMapper(yamlMapper);
}

private static Map<Class, ObjectMapper> legacyMappers = new ConcurrentHashMap<>();

/**
* Register a set of classes (Transform, Filter, etc) for JSON deserialization.<br>
* <br>
* This is required ONLY when BOTH of the following conditions are met:<br>
* 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND<br>
* 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)<br>
* <br>
* By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON
* format change between versions.
*
* @param classes Classes to register
*/
public static void registerLegacyCustomClassesForJSON(Class<?>... classes) {
registerLegacyCustomClassesForJSONList(Arrays.<Class<?>>asList(classes));
}

/**
* @see #registerLegacyCustomClassesForJSON(Class[])
*/
public static void registerLegacyCustomClassesForJSONList(List<Class<?>> classes){
//Default names (i.e., old format for custom JSON format)
List<Pair<String,Class>> list = new ArrayList<>();
for(Class<?> c : classes){
list.add(new Pair<String,Class>(c.getSimpleName(), c));
public static synchronized ObjectMapper getLegacyMapper(){
if(legacyMapper == null){
legacyMapper = LegacyJsonFormat.legacyMapper();
configureMapper(legacyMapper);
}
registerLegacyCustomClassesForJSON(list);
}

/**
* Set of classes that can be registered for legacy deserialization.
*/
private static List<Class<?>> REGISTERABLE_CUSTOM_CLASSES = (List<Class<?>>) Arrays.<Class<?>>asList(
Transform.class,
ColumnAnalysis.class,
ColumnCondition.class,
Filter.class,
ColumnMetaData.class,
CalculateSortedRank.class,
Schema.class,
SequenceComparator.class,
SequenceSplit.class,
WindowFunction.class,
Writable.class,
WritableComparator.class
);

/**
* Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
* ONLY) for JSON deserialization, with custom names.<br>
* Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
* but is added in case it is required in non-standard circumstances.
*/
public static void registerLegacyCustomClassesForJSON(List<Pair<String,Class>> classes){
for(Pair<String,Class> p : classes){
String s = p.getFirst();
Class c = p.getRight();
//Check if it's a valid class to register...
boolean found = false;
for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){
if(c2.isAssignableFrom(c)){
Map<String,String> map = LegacyMappingHelper.legacyMappingForClass(c2);
map.put(p.getFirst(), p.getSecond().getName());
found = true;
}
}

if(!found){
throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
}
}
}

/**
* Get the legacy JSON mapper for the specified class.<br>
*
* <b>NOTE</b>: This is intended for internal backward-compatibility use.
*
* Note to developers: The following JSON mappers are for handling legacy format JSON.
* Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc were changed from
* a wrapper object, to an "@class" field. However, to not break all saved transforms networks, these mappers are
* part of the solution.<br>
* <br>
* How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)<br>
* 1. Transforms etc JSON that has a "@class" field are deserialized as normal<br>
* 2. Transforms JSON that don't have such a field are mapped (via Layer @JsonTypeInfo) to LegacyMappingHelper.TransformHelper<br>
* 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it<br>
* 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names
* 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization
*
* Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format,
* as it'll fail due to not having the expected "@class" annotation.
* Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified
* class anyway. The ignoring is done via an annotation introspector, defined below in this class.
* However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of
* all types - if we did, then any nested types would fail (i.e., an Condition in a Transform - the Transform couldn't
* be deserialized correctly, as the annotation would be ignored).
*
*/
public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class<?> clazz){
if(!legacyMappers.containsKey(clazz)){
ObjectMapper m = new ObjectMapper();
configureMapper(m);
m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.<Class>singletonList(clazz)));
legacyMappers.put(clazz, m);
}
return legacyMappers.get(clazz);
return legacyMapper;
}

/**

@@ -237,61 +76,7 @@
ret.enable(SerializationFeature.INDENT_OUTPUT);
ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); //Need this otherwise JsonProperty annotations on constructors won't be seen
}

/**
* Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc.
* This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring
* a set of JsonTypeInfo annotations
*/
@AllArgsConstructor
private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector {

private List<Class> classList;

@Override
protected TypeResolverBuilder<?> _findTypeResolver(MapperConfig<?> config, Annotated ann, JavaType baseType) {
if(ann instanceof AnnotatedClass){
AnnotatedClass c = (AnnotatedClass)ann;
Class<?> annClass = c.getAnnotated();

boolean isAssignable = false;
for(Class c2 : classList){
if(c2.isAssignableFrom(annClass)){
isAssignable = true;
break;
}
}

if( isAssignable ){
AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations();
if(annotations == null || annotations.annotations() == null){
//Probably not necessary - but here for safety
return super._findTypeResolver(config, ann, baseType);
}

AnnotationMap newMap = null;
for(Annotation a : annotations.annotations()){
Class<?> annType = a.annotationType();
if(annType == JsonTypeInfo.class){
//Ignore the JsonTypeInfo annotation on the Layer class
continue;
}
if(newMap == null){
newMap = new AnnotationMap();
}
newMap.add(a);
}
if(newMap == null)
return null;

//Pass the remaining annotations (if any) to the original introspector
AnnotatedClass ann2 = c.withAnnotations(newMap);
return super._findTypeResolver(config, ann2, baseType);
}
}
return super._findTypeResolver(config, ann, baseType);
}
}
}
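Taken together, the JsonMappers changes above replace the per-class legacy mappers and custom-class registration with a single lazily-created legacy ObjectMapper built from LegacyJsonFormat (added further below). A minimal sketch, not part of the commit, of how the new entry point is used by the fromJson methods earlier in this diff (oldFormatJson stands for a string saved with 1.0.0-alpha or earlier):

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.nd4j.shade.jackson.databind.ObjectMapper;

// Illustrative only: reading pre-1.0.0-beta (wrapper-object) JSON via the single legacy mapper.
public class LegacySchemaRead {
    public static Schema readLegacy(String oldFormatJson) throws java.io.IOException {
        ObjectMapper legacy = JsonMappers.getLegacyMapper();   // built from LegacyJsonFormat.legacyMapper()
        return legacy.readValue(oldFormatJson, Schema.class);
    }
}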
@@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.api.transform.serde.legacy;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import org.datavec.api.transform.serde.JsonMappers;
import org.nd4j.serde.json.BaseLegacyDeserializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;

import java.util.Map;

@AllArgsConstructor
@Data
public class GenericLegacyDeserializer<T> extends BaseLegacyDeserializer<T> {

@Getter
protected final Class<T> deserializedType;
@Getter
protected final Map<String,String> legacyNamesMap;

@Override
public ObjectMapper getLegacyJsonMapper() {
return JsonMappers.getLegacyMapperFor(getDeserializedType());
}
}

@@ -0,0 +1,267 @@
package org.datavec.api.transform.serde.legacy;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.*;
import org.datavec.api.transform.condition.BooleanCondition;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.condition.column.*;
import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.filter.InvalidNumColumns;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ReduceSequenceTransform;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
import org.datavec.api.transform.sequence.comparator.StringComparator;
import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.TimeWindowFunction;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.stringreduce.IStringReducer;
import org.datavec.api.transform.stringreduce.StringReducer;
import org.datavec.api.transform.transform.categorical.*;
import org.datavec.api.transform.transform.column.*;
import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
import org.datavec.api.transform.transform.doubletransform.*;
import org.datavec.api.transform.transform.integer.*;
import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
import org.datavec.api.transform.transform.string.*;
import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
import org.datavec.api.transform.transform.time.StringToTimeTransform;
import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*;
import org.datavec.api.writable.comparator.*;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.ObjectMapper;

/**
* This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override
* the existing annotations.
* In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field".
* We use these mixins to allow us to still load the old format
*
* @author Alex Black
*/
public class LegacyJsonFormat {

private LegacyJsonFormat(){ }

/**
* Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before
* @return Object mapper
*/
public static ObjectMapper legacyMapper(){
ObjectMapper om = new ObjectMapper();
om.addMixIn(Schema.class, SchemaMixin.class);
om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class);
om.addMixIn(Transform.class, TransformMixin.class);
om.addMixIn(Condition.class, ConditionMixin.class);
om.addMixIn(Writable.class, WritableMixin.class);
om.addMixIn(Filter.class, FilterMixin.class);
om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class);
om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class);
om.addMixIn(WindowFunction.class, WindowFunctionMixin.class);
om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class);
om.addMixIn(WritableComparator.class, WritableComparatorMixin.class);
om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class);
om.addMixIn(IStringReducer.class, IStringReducerMixin.class);
return om;
}
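// For orientation (editor-added comment, not part of the commit): the mixins below reproduce the
// pre-1.0.0-beta "wrapper object" subtype format, as opposed to the current "@class" field format
// used elsewhere in this commit. Roughly, with field names shown only for illustration:
//   legacy:  {"RemoveColumnsTransform": {"columnsToRemove": ["myColumn"]}}
//   current: {"@class": "org.datavec.api.transform.transform.column.RemoveColumnsTransform",
//             "columnsToRemove": ["myColumn"]}
// The JsonTypeInfo.Id.NAME / As.WRAPPER_OBJECT annotations on each mixin select the legacy shape,
// and the @JsonSubTypes lists map the old simple names back to the current classes.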
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"),
@JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class SchemaMixin { }

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"),
@JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"),
@JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"),
@JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"),
@JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"),
@JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"),
@JsonSubTypes.Type(value = LongMetaData.class, name = "Long"),
@JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"),
@JsonSubTypes.Type(value = StringMetaData.class, name = "String"),
@JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class ColumnMetaDataMixin { }

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"),
@JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"),
@JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"),
@JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"),
@JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"),
@JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"),
@JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"),
@JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"),
@JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"),
@JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"),
@JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"),
@JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"),
@JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"),
@JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"),
@JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"),
@JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"),
@JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"),
@JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"),
@JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"),
@JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"),
@JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"),
@JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"),
@JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"),
@JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"),
@JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"),
@JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"),
@JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"),
@JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"),
@JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"),
@JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"),
@JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"),
@JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"),
@JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"),
@JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"),
@JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"),
@JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"),
|
||||
@JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"),
|
||||
@JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"),
|
||||
@JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"),
|
||||
@JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"),
|
||||
@JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"),
|
||||
@JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"),
|
||||
@JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"),
|
||||
@JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"),
|
||||
@JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"),
|
||||
@JsonSubTypes.Type(value = SequenceOffsetTransform.class, name = "SequenceOffsetTransform"),
|
||||
@JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"),
|
||||
@JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"),
|
||||
@JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"),
|
||||
@JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"),
|
||||
@JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"),
|
||||
@JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"),
|
||||
@JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"),
|
||||
@JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"),
|
||||
@JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"),
|
||||
@JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class TransformMixin { }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"),
|
||||
@JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"),
|
||||
@JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"),
|
||||
@JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"),
|
||||
@JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"),
|
||||
@JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"),
|
||||
@JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"),
|
||||
@JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"),
|
||||
@JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"),
|
||||
@JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"),
|
||||
@JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"),
|
||||
@JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"),
|
||||
@JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class ConditionMixin { }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"),
|
||||
@JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"),
|
||||
@JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"),
|
||||
@JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"),
|
||||
@JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"),
|
||||
@JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"),
|
||||
@JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"),
|
||||
@JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"),
|
||||
@JsonSubTypes.Type(value = Text.class, name = "Text"),
|
||||
@JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class WritableMixin { }
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"),
|
||||
@JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"),
|
||||
@JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class FilterMixin { }
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"),
|
||||
@JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class SequenceComparatorMixin { }
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"),
|
||||
@JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class SequenceSplitMixin { }
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"),
|
||||
@JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class WindowFunctionMixin { }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class CalculateSortedRankMixin { }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"),
|
||||
@JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"),
|
||||
@JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"),
|
||||
@JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"),
|
||||
@JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class WritableComparatorMixin { }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"),
|
||||
@JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"),
|
||||
@JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"),
|
||||
@JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"),
|
||||
@JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"),
|
||||
@JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"),
|
||||
@JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class ColumnAnalysisMixin{ }
|
||||
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
|
||||
@JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")})
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class IStringReducerMixin{ }
|
||||
}
|
|
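For context, a minimal usage sketch of the mixin-based mapper above (illustrative only, not part of this diff). It assumes LegacyJsonFormat lives in org.datavec.api.transform.serde.legacy, uses made-up column names, and assumes the legacy mapper can also write the wrapper-object form so a round trip works:

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;   // assumed package
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class LegacyJsonFormatSketch {
    public static void main(String[] args) throws Exception {
        // Round-trip a Schema through the pre-1.0.0-beta "wrapper object" JSON format.
        ObjectMapper legacy = LegacyJsonFormat.legacyMapper();

        Schema schema = new Schema.Builder()
                .addColumnDouble("myDouble")    // hypothetical columns
                .addColumnString("myString")
                .build();

        String legacyJson = legacy.writeValueAsString(schema);        // e.g. {"Schema":{...}}
        Schema restored = legacy.readValue(legacyJson, Schema.class); // read back via the mixins

        System.out.println(restored.numColumns());                    // expect 2
    }
}
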
@ -1,535 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.api.transform.serde.legacy;

import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.*;
import org.datavec.api.transform.condition.BooleanCondition;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.condition.column.*;
import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.filter.InvalidNumColumns;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ReduceSequenceTransform;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
import org.datavec.api.transform.sequence.comparator.StringComparator;
import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.TimeWindowFunction;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.stringreduce.IStringReducer;
import org.datavec.api.transform.stringreduce.StringReducer;
import org.datavec.api.transform.transform.categorical.*;
import org.datavec.api.transform.transform.column.*;
import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
import org.datavec.api.transform.transform.doubletransform.*;
import org.datavec.api.transform.transform.integer.*;
import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform;
import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
import org.datavec.api.transform.transform.string.*;
import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
import org.datavec.api.transform.transform.time.StringToTimeTransform;
import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*;
import org.datavec.api.writable.comparator.*;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;

import java.util.HashMap;
import java.util.Map;

public class LegacyMappingHelper {

public static Map<String,String> legacyMappingForClass(Class c){
//Need to be able to get the map - and they need to be mutable...
switch (c.getSimpleName()){
case "Transform":
return getLegacyMappingImageTransform();
case "ColumnAnalysis":
return getLegacyMappingColumnAnalysis();
case "Condition":
return getLegacyMappingCondition();
case "Filter":
return getLegacyMappingFilter();
case "ColumnMetaData":
return mapColumnMetaData;
case "CalculateSortedRank":
return mapCalculateSortedRank;
case "Schema":
return mapSchema;
case "SequenceComparator":
return mapSequenceComparator;
case "SequenceSplit":
return mapSequenceSplit;
case "WindowFunction":
return mapWindowFunction;
case "IStringReducer":
return mapIStringReducer;
case "Writable":
return mapWritable;
case "WritableComparator":
return mapWritableComparator;
case "ImageTransform":
return mapImageTransform;
default:
//Should never happen
throw new IllegalArgumentException("No legacy mapping available for class " + c.getName());
}
}

private static Map<String,String> mapTransform;
private static Map<String,String> mapColumnAnalysis;
private static Map<String,String> mapCondition;
private static Map<String,String> mapFilter;
private static Map<String,String> mapColumnMetaData;
private static Map<String,String> mapCalculateSortedRank;
private static Map<String,String> mapSchema;
private static Map<String,String> mapSequenceComparator;
private static Map<String,String> mapSequenceSplit;
private static Map<String,String> mapWindowFunction;
private static Map<String,String> mapIStringReducer;
private static Map<String,String> mapWritable;
private static Map<String,String> mapWritableComparator;
private static Map<String,String> mapImageTransform;

private static synchronized Map<String,String> getLegacyMappingTransform(){

if(mapTransform == null) {
//The following classes all used their class short name
Map<String, String> m = new HashMap<>();
m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName());
m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName());
m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName());
m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName());
m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName());
m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName());
m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName());
m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName());
m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName());
m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName());
m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName());
m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName());
m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName());
m.put("Log2Normalizer", Log2Normalizer.class.getName());
m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName());
m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName());
m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName());
m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName());
m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName());
m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName());
m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName());
m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName());
m.put("LongMathOpTransform", LongMathOpTransform.class.getName());
m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName());
m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName());
m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName());
m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName());
m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName());
m.put("StringMapTransform", StringMapTransform.class.getName());
m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName());
m.put("StringToTimeTransform", StringToTimeTransform.class.getName());
m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName());
m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName());
m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName());
m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName());
m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName());
m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName());
m.put("ConvertToStringTransform", ConvertToString.class.getName());
m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName());
m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName());
m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName());
m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName());
m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName());
m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName());
m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName());
m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName());
m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName());
m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName());
m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName());
m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName());
m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName());
m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName());
m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName());
m.put("PivotTransform", PivotTransform.class.getName());
m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName());
m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName());
m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName());

mapTransform = m;
}

return mapTransform;
}

private static Map<String,String> getLegacyMappingColumnAnalysis(){
if(mapColumnAnalysis == null) {
Map<String, String> m = new HashMap<>();
m.put("BytesAnalysis", BytesAnalysis.class.getName());
m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName());
m.put("DoubleAnalysis", DoubleAnalysis.class.getName());
m.put("IntegerAnalysis", IntegerAnalysis.class.getName());
m.put("LongAnalysis", LongAnalysis.class.getName());
m.put("StringAnalysis", StringAnalysis.class.getName());
m.put("TimeAnalysis", TimeAnalysis.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName());

mapColumnAnalysis = m;
}

return mapColumnAnalysis;
}

private static Map<String,String> getLegacyMappingCondition(){
if(mapCondition == null) {
Map<String, String> m = new HashMap<>();
m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName());
m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName());
m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName());
m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName());
m.put("LongColumnCondition", LongColumnCondition.class.getName());
m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName());
m.put("StringColumnCondition", StringColumnCondition.class.getName());
m.put("TimeColumnCondition", TimeColumnCondition.class.getName());
m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName());
m.put("BooleanCondition", BooleanCondition.class.getName());
m.put("NaNColumnCondition", NaNColumnCondition.class.getName());
m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName());
m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName());
m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName());

mapCondition = m;
}

return mapCondition;
}

private static Map<String,String> getLegacyMappingFilter(){
if(mapFilter == null) {
Map<String, String> m = new HashMap<>();
m.put("ConditionFilter", ConditionFilter.class.getName());
m.put("FilterInvalidValues", FilterInvalidValues.class.getName());
m.put("InvalidNumCols", InvalidNumColumns.class.getName());

mapFilter = m;
}
return mapFilter;
}

private static Map<String,String> getLegacyMappingColumnMetaData(){
if(mapColumnMetaData == null) {
Map<String, String> m = new HashMap<>();
m.put("Categorical", CategoricalMetaData.class.getName());
m.put("Double", DoubleMetaData.class.getName());
m.put("Float", FloatMetaData.class.getName());
m.put("Integer", IntegerMetaData.class.getName());
m.put("Long", LongMetaData.class.getName());
m.put("String", StringMetaData.class.getName());
m.put("Time", TimeMetaData.class.getName());
m.put("NDArray", NDArrayMetaData.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName());
m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName());

mapColumnMetaData = m;
}

return mapColumnMetaData;
}

private static Map<String,String> getLegacyMappingCalculateSortedRank(){
if(mapCalculateSortedRank == null) {
Map<String, String> m = new HashMap<>();
m.put("CalculateSortedRank", CalculateSortedRank.class.getName());
mapCalculateSortedRank = m;
}
return mapCalculateSortedRank;
}

private static Map<String,String> getLegacyMappingSchema(){
if(mapSchema == null) {
Map<String, String> m = new HashMap<>();
m.put("Schema", Schema.class.getName());
m.put("SequenceSchema", SequenceSchema.class.getName());

mapSchema = m;
}
return mapSchema;
}

private static Map<String,String> getLegacyMappingSequenceComparator(){
if(mapSequenceComparator == null) {
Map<String, String> m = new HashMap<>();
m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName());
m.put("StringComparator", StringComparator.class.getName());

mapSequenceComparator = m;
}
return mapSequenceComparator;
}

private static Map<String,String> getLegacyMappingSequenceSplit(){
if(mapSequenceSplit == null) {
Map<String, String> m = new HashMap<>();
m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName());
m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName());

mapSequenceSplit = m;
}
return mapSequenceSplit;
}

private static Map<String,String> getLegacyMappingWindowFunction(){
if(mapWindowFunction == null) {
Map<String, String> m = new HashMap<>();
m.put("TimeWindowFunction", TimeWindowFunction.class.getName());
m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName());

mapWindowFunction = m;
}
return mapWindowFunction;
}

private static Map<String,String> getLegacyMappingIStringReducer(){
if(mapIStringReducer == null) {
Map<String, String> m = new HashMap<>();
m.put("StringReducer", StringReducer.class.getName());

mapIStringReducer = m;
}
return mapIStringReducer;
}

private static Map<String,String> getLegacyMappingWritable(){
if (mapWritable == null) {
Map<String, String> m = new HashMap<>();
m.put("ArrayWritable", ArrayWritable.class.getName());
m.put("BooleanWritable", BooleanWritable.class.getName());
m.put("ByteWritable", ByteWritable.class.getName());
m.put("DoubleWritable", DoubleWritable.class.getName());
m.put("FloatWritable", FloatWritable.class.getName());
m.put("IntWritable", IntWritable.class.getName());
m.put("LongWritable", LongWritable.class.getName());
m.put("NullWritable", NullWritable.class.getName());
m.put("Text", Text.class.getName());
m.put("BytesWritable", BytesWritable.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName());

mapWritable = m;
}

return mapWritable;
}

private static Map<String,String> getLegacyMappingWritableComparator(){
if(mapWritableComparator == null) {
Map<String, String> m = new HashMap<>();
m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName());
m.put("FloatWritableComparator", FloatWritableComparator.class.getName());
m.put("IntWritableComparator", IntWritableComparator.class.getName());
m.put("LongWritableComparator", LongWritableComparator.class.getName());
m.put("TextWritableComparator", TextWritableComparator.class.getName());

//The following never had subtype annotations, and hence will have had the default name:
m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName());
m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName());
m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName());
m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName());
m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName());
m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName());
m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName());
m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName());

mapWritableComparator = m;
}

return mapWritableComparator;
}

public static Map<String,String> getLegacyMappingImageTransform(){
if(mapImageTransform == null) {
Map<String, String> m = new HashMap<>();
m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform");
m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform");
m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform");
m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform");
m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform");
m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform");
m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform");
m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform");
m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform");
m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform");
m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform");
m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform");

mapImageTransform = m;
}
return mapImageTransform;
}

@JsonDeserialize(using = LegacyTransformDeserializer.class)
public static class TransformHelper { }

public static class LegacyTransformDeserializer extends GenericLegacyDeserializer<Transform> {
public LegacyTransformDeserializer() {
super(Transform.class, getLegacyMappingTransform());
}
}

@JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class)
public static class ColumnAnalysisHelper { }

public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer<ColumnAnalysis> {
public LegacyColumnAnalysisDeserializer() {
super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis());
}
}

@JsonDeserialize(using = LegacyConditionDeserializer.class)
public static class ConditionHelper { }

public static class LegacyConditionDeserializer extends GenericLegacyDeserializer<Condition> {
public LegacyConditionDeserializer() {
super(Condition.class, getLegacyMappingCondition());
}
}

@JsonDeserialize(using = LegacyFilterDeserializer.class)
public static class FilterHelper { }

public static class LegacyFilterDeserializer extends GenericLegacyDeserializer<Filter> {
public LegacyFilterDeserializer() {
super(Filter.class, getLegacyMappingFilter());
}
}

@JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class)
public static class ColumnMetaDataHelper { }

public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer<ColumnMetaData> {
public LegacyColumnMetaDataDeserializer() {
super(ColumnMetaData.class, getLegacyMappingColumnMetaData());
}
}

@JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class)
public static class CalculateSortedRankHelper { }

public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer<CalculateSortedRank> {
public LegacyCalculateSortedRankDeserializer() {
super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank());
}
}

@JsonDeserialize(using = LegacySchemaDeserializer.class)
public static class SchemaHelper { }

public static class LegacySchemaDeserializer extends GenericLegacyDeserializer<Schema> {
public LegacySchemaDeserializer() {
super(Schema.class, getLegacyMappingSchema());
}
}

@JsonDeserialize(using = LegacySequenceComparatorDeserializer.class)
public static class SequenceComparatorHelper { }

public static class LegacySequenceComparatorDeserializer extends GenericLegacyDeserializer<SequenceComparator> {
public LegacySequenceComparatorDeserializer() {
super(SequenceComparator.class, getLegacyMappingSequenceComparator());
}
}

@JsonDeserialize(using = LegacySequenceSplitDeserializer.class)
public static class SequenceSplitHelper { }

public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer<SequenceSplit> {
public LegacySequenceSplitDeserializer() {
super(SequenceSplit.class, getLegacyMappingSequenceSplit());
}
}

@JsonDeserialize(using = LegacyWindowFunctionDeserializer.class)
public static class WindowFunctionHelper { }

public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer<WindowFunction> {
public LegacyWindowFunctionDeserializer() {
super(WindowFunction.class, getLegacyMappingWindowFunction());
}
}

@JsonDeserialize(using = LegacyIStringReducerDeserializer.class)
public static class IStringReducerHelper { }

public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer<IStringReducer> {
public LegacyIStringReducerDeserializer() {
super(IStringReducer.class, getLegacyMappingIStringReducer());
}
}

@JsonDeserialize(using = LegacyWritableDeserializer.class)
public static class WritableHelper { }

public static class LegacyWritableDeserializer extends GenericLegacyDeserializer<Writable> {
public LegacyWritableDeserializer() {
super(Writable.class, getLegacyMappingWritable());
}
}

@JsonDeserialize(using = LegacyWritableComparatorDeserializer.class)
public static class WritableComparatorHelper { }

public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer<WritableComparator> {
public LegacyWritableComparatorDeserializer() {
super(WritableComparator.class, getLegacyMappingWritableComparator());
}
}
}
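The file removed above mapped legacy subtype names to fully-qualified class names, and its helper deserializers resolved a name through that map before binding the JSON to the resolved class. A rough sketch of that lookup pattern (illustrative only; the names and flow are simplified and this is not the actual GenericLegacyDeserializer code):

import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.util.Map;

public class LegacyLookupSketch {
    // Resolve a legacy type name (the old wrapper-object key) to its class, then bind the payload to it.
    public static Object readLegacy(Map<String, String> legacyNameToClass,
                                    String legacyName, String payloadJson) throws Exception {
        Class<?> target = Class.forName(legacyNameToClass.get(legacyName));
        return new ObjectMapper().readValue(payloadJson, target);
    }
}
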
@ -17,7 +17,6 @@
package org.datavec.api.transform.stringreduce;

import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

@ -31,8 +30,7 @@ import java.util.List;
* a single List<Writable>
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.IStringReducerHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface IStringReducer extends Serializable {

/**

@ -16,7 +16,7 @@

package org.datavec.api.util.ndarray;

import com.google.common.base.Preconditions;
import org.nd4j.shade.guava.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;

@ -17,7 +17,7 @@
package org.datavec.api.writable;


import com.google.common.math.DoubleMath;
import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@ -17,7 +17,7 @@
package org.datavec.api.writable;


import com.google.common.math.DoubleMath;
import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@ -17,7 +17,7 @@
package org.datavec.api.writable;


import com.google.common.math.DoubleMath;
import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@ -17,7 +17,7 @@
package org.datavec.api.writable;


import com.google.common.math.DoubleMath;
import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@ -17,7 +17,7 @@
package org.datavec.api.writable;


import com.google.common.math.DoubleMath;
import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@ -16,7 +16,6 @@

package org.datavec.api.writable;

import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.io.DataInput;

@ -60,8 +59,7 @@ import java.io.Serializable;
* }
* </pre></blockquote></p>
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.WritableHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface Writable extends Serializable {
/**
* Serialize the fields of this object to <code>out</code>.
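For reference, a sketch of the two serialized forms that this annotation change on Writable moves between. The exact field layout of DoubleWritable is assumed here, so treat the JSON shapes as illustrative:

import org.datavec.api.writable.DoubleWritable;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class TypeInfoShapeSketch {
    public static void main(String[] args) throws Exception {
        // Current format, driven by the Id.CLASS / "@class" annotation shown above:
        //   something like {"@class":"org.datavec.api.writable.DoubleWritable","value":1.0}
        System.out.println(new ObjectMapper().writeValueAsString(new DoubleWritable(1.0)));

        // Old 1.0.0-alpha format (wrapper object), still readable via LegacyJsonFormat.legacyMapper():
        //   {"DoubleWritable":{"value":1.0}}
    }
}
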
@ -16,7 +16,7 @@

package org.datavec.api.writable.batch;

import com.google.common.base.Preconditions;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.Data;
import lombok.NonNull;
import org.datavec.api.writable.NDArrayWritable;

@ -16,16 +16,13 @@

package org.datavec.api.writable.comparator;

import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.io.Serializable;
import java.util.Comparator;

@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface WritableComparator extends Comparator<Writable>, Serializable {

}

@ -16,7 +16,7 @@

package org.datavec.api.split;

import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;

@ -16,7 +16,7 @@

package org.datavec.api.split.parittion;

import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;

@ -78,8 +78,9 @@ public class TestJsonYaml {
public void testMissingPrimitives() {

Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build();

String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n"
//Legacy format JSON
String strJson = "{\n" + " \"Schema\" : {\n"
+ " \"columns\" : [ {\n" + " \"Double\" : {\n"
+ " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" +
//" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test
//" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test

@ -16,7 +16,7 @@

package org.datavec.api.writable;

import com.google.common.collect.Lists;
import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test;
@ -34,42 +34,6 @@
<artifactId>nd4j-arrow</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-xml</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-joda</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>

@ -80,11 +44,6 @@
<artifactId>hppc</artifactId>
<version>${hppc.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>

@ -43,12 +43,6 @@
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!--

@ -60,7 +54,6 @@
-->
</dependencies>


<profiles>
<profile>
<id>test-nd4j-native</id>

@ -31,20 +31,12 @@
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-buffer</artifactId>

@ -75,7 +67,6 @@
<artifactId>imageio-bmp</artifactId>
<version>3.1.1</version>
</dependency>

<dependency>
<groupId>com.google.android</groupId>
<artifactId>android</artifactId>

@ -88,7 +79,6 @@
</exclusions>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacpp</artifactId>

@ -99,25 +89,21 @@
<artifactId>javacv</artifactId>
<version>${javacv.version}</version>
</dependency>

<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>opencv-platform</artifactId>
<version>${opencv.version}-${javacpp-presets.version}</version>
</dependency>

<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>leptonica-platform</artifactId>
<version>${leptonica.version}-${javacpp-presets.version}</version>
</dependency>

<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>hdf5-platform</artifactId>
<version>${hdf5.version}-${javacpp-presets.version}</version>
</dependency>

</dependencies>

<build>

@ -143,5 +129,4 @@
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>

</project>
@ -16,7 +16,7 @@

package org.datavec.image.recordreader;

import com.google.common.base.Preconditions;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

@ -1,35 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.image.serde;

import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;

public class LegacyImageMappingHelper {

@JsonDeserialize(using = LegacyImageTransformDeserializer.class)
public static class ImageTransformHelper { }

public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer<ImageTransform> {
public LegacyImageTransformDeserializer() {
super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform());
}
}

}
@ -16,11 +16,8 @@

package org.datavec.image.transform;

import lombok.Data;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.serde.LegacyImageMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;

import java.util.Random;

@ -31,8 +28,7 @@ import java.util.Random;
* @author saudet
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class",
defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
public interface ImageTransform {

/**

@ -31,7 +31,6 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<cleartk.version>2.0.0</cleartk.version>

</properties>

<dependencies>

@ -75,13 +74,6 @@
<artifactId>cleartk-opennlp-tools</artifactId>
<version>${cleartk.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>

@ -50,11 +50,6 @@
<artifactId>netty</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>

@ -95,14 +90,6 @@
</exclusion>
</exclusions>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
@ -16,7 +16,7 @@
|
|||
|
||||
package org.datavec.hadoop.records.reader;
|
||||
|
||||
import com.google.common.io.Files;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.io.*;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.datavec.hadoop.records.reader;
|
||||
|
||||
import com.google.common.io.Files;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.io.*;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.datavec.hadoop.records.reader;
|
||||
|
||||
import com.google.common.io.Files;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
import org.apache.hadoop.io.*;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.datavec.hadoop.records.writer;
|
||||
|
||||
import com.google.common.io.Files;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.datavec.api.records.converter.RecordReaderConverter;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.SequenceRecordReader;
|
||||
|
|
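The four Hadoop record reader/writer hunks above all follow the same pattern as the earlier Guava changes: the direct com.google.common import is swapped for ND4J's shaded copy, so no separate Guava artifact is needed on the classpath. A minimal sketch of the swap, assuming the shaded package keeps Guava's Files.createTempDir() signature (the class below is illustrative, not from the diff):

// Hypothetical helper showing the shaded-Guava import swap; only the import line changes.
// Before: import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;

public class ShadedGuavaSketch {
    public static java.io.File newTempDir() {
        // Same Guava API, relocated under org.nd4j.shade to avoid version clashes
        return Files.createTempDir();
    }
}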
@ -50,12 +50,6 @@
<artifactId>protonpack</artifactId>
<version>${protonpack.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>

@ -16,7 +16,7 @@
package org.datavec.local.transforms.join;
import com.google.common.collect.Iterables;
import org.nd4j.shade.guava.collect.Iterables;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.functions.FlatMapFunctionAdapter;

@ -52,12 +52,6 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>

@ -26,10 +26,8 @@
<artifactId>datavec-python</artifactId>

<dependencies>
<dependency>
<dependency>
<groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId>
<version>1.1</version>

@ -39,11 +37,6 @@
<artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>

@ -54,14 +47,19 @@
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>

<profiles>
<profile>
<id>test-nd4j-native</id>

@ -70,5 +68,4 @@
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>

</project>

@ -30,12 +30,6 @@
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-jackson</artifactId>

@ -35,12 +35,6 @@
<artifactId>datavec-api</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>

@ -64,12 +64,6 @@
<version>${datavec.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-cluster_2.11</artifactId>
<version>${akka.version}</version>
</dependency>

<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>

@ -106,40 +100,10 @@
<version>${snakeyaml.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
<version>${play.version}</version>
<version>${playframework.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>

@ -161,25 +125,31 @@
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${play.version}</version>
<version>${playframework.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${play.version}</version>
<version>${playframework.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId>
<version>${play.version}</version>
<version>${playframework.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${play.version}</version>
<version>${playframework.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-cluster_2.11</artifactId>
<version>2.5.23</version>
</dependency>

<dependency>

@ -195,14 +165,11 @@
<version>${jcommander.version}</version>
</dependency>

<!-- Test Scope Dependencies -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>${spark.version}</version>
</dependency>

</dependencies>

<profiles>
@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Random;

import static play.mvc.Results.*;

@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer {
System.exit(1);
}

RoutingDsl routingDsl = new RoutingDsl();

if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
TransformProcess transformProcess = TransformProcess.fromJson(json);

@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess");
}

//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
String crypto = System.getProperty("play.crypto.secret");
if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
byte[] newCrypto = new byte[1024];
routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
new Random().nextBytes(newCrypto);
String base64 = Base64.getEncoder().encodeToString(newCrypto);
System.setProperty("play.crypto.secret", base64);
}

server = Server.forRouter(Mode.PROD, port, this::createRouter);
}

protected Router createRouter(BuiltInComponents b){
RoutingDsl routingDsl = RoutingDsl.fromComponents(b);

routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();

@ -88,11 +107,11 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in GET /transformprocess",e);
return internalServerError(e.getMessage());
}
})));
});

routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformprocess").routingTo(req -> {
try {
TransformProcess transformProcess = TransformProcess.fromJson(getJsonText());
TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
setCSVTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);

@ -100,12 +119,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in POST /transformprocess",e);
return internalServerError(e.getMessage());
}
})));
});

routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
routingDsl.POST("/transformincremental").routingTo(req -> {
if (isSequence(req)) {
try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);

@ -115,7 +134,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);

@ -124,12 +143,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
})));
});

routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
routingDsl.POST("/transform").routingTo(req -> {
if (isSequence(req)) {
try {
SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class));
SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(batch)).as(contentType);

@ -139,7 +158,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
BatchCSVRecord batch = transform(input);
if (batch == null)
return badRequest();

@ -149,14 +168,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
});

})));

routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
routingDsl.POST("/transformincrementalarray").routingTo(req -> {
if (isSequence(req)) {
try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);

@ -166,7 +183,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class);
SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);

@ -175,13 +192,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
});

})));

routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
routingDsl.POST("/transformarray").routingTo(req -> {
if (isSequence(req)) {
try {
SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class);
SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);

@ -191,7 +207,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
}
} else {
try {
BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class);
BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);

@ -200,10 +216,9 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage());
}
}
})));
});

server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
return routingDsl.build();
}

public static void main(String[] args) throws Exception {
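The CSVSparkTransformServer hunks above replace the deprecated routeTo(FunctionUtil.function0(...)) registration and the old Server.forRouter(routingDsl.build(), Mode.PROD, port) call with Play's component-based API: the server is started with Server.forRouter(Mode.PROD, port, this::createRouter), and each handler receives the Http.Request explicitly. A minimal standalone sketch of that pattern, assuming Play 2.6+ on the classpath (the port value and handler body are placeholders, not taken from the diff):

// Illustrative sketch of the new routing style used in this commit.
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import static play.mvc.Results.ok;

public class RoutingSketch {
    public static void main(String[] args) {
        int port = 9000; // placeholder port
        // The router is built lazily from the server's own components
        Server server = Server.forRouter(Mode.PROD, port, RoutingSketch::createRouter);
    }

    static Router createRouter(BuiltInComponents components) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(components);
        // Handlers now receive the Http.Request as a lambda argument instead of
        // reading a static, thread-local request()
        routingDsl.GET("/transformprocess").routingTo(req -> ok("status"));
        return routingDsl.build();
    }
}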
@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform;

import play.libs.F;
import play.mvc.Result;

import java.util.function.Function;
import java.util.function.Supplier;

/**
* Utility methods for Routing
*
* @author Alex Black
*/
public class FunctionUtil {

public static F.Function0<Result> function0(Supplier<Result> supplier) {
return supplier::get;
}

public static <T> F.Function<T, Result> function(Function<T, Result> function) {
return function::apply;
}

}
@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Files;
import play.mvc.Http;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

@ -33,6 +36,7 @@ import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import static play.mvc.Controller.request;
import static play.mvc.Results.*;

@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer {
System.exit(1);
}

RoutingDsl routingDsl = new RoutingDsl();

if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);

@ -73,7 +75,13 @@ public class ImageSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess");
}

routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> {
server = Server.forRouter(Mode.PROD, port, this::createRouter);
}

protected Router createRouter(BuiltInComponents builtInComponents){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);

routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();

@ -83,11 +91,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformprocess").routingTo(req -> {
try {
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText());
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
setImageTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);

@ -95,11 +103,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformincrementalarray").routingTo(req -> {
try {
SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class);
SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);

@ -107,17 +115,17 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformincrementalimage").routingTo(req -> {
try {
Http.MultipartFormData body = request().body().asMultipartFormData();
List<Http.MultipartFormData.FilePart> files = body.getFiles();
if (files.size() == 0 || files.get(0).getFile() == null) {
Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.isEmpty() || files.get(0).getRef() == null ) {
return badRequest();
}

File file = files.get(0).getFile();
File file = files.get(0).getRef().path().toFile();
SingleImageRecord record = new SingleImageRecord(file.toURI());

return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);

@ -125,11 +133,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformarray").routingTo(req -> {
try {
BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class);
BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);

@ -137,22 +145,22 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> {
routingDsl.POST("/transformimage").routingTo(req -> {
try {
Http.MultipartFormData body = request().body().asMultipartFormData();
List<Http.MultipartFormData.FilePart> files = body.getFiles();
Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.size() == 0) {
return badRequest();
}

List<SingleImageRecord> records = new ArrayList<>();

for (Http.MultipartFormData.FilePart filePart : files) {
File file = filePart.getFile();
for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
Files.TemporaryFile file = filePart.getRef();
if (file != null) {
SingleImageRecord record = new SingleImageRecord(file.toURI());
SingleImageRecord record = new SingleImageRecord(file.path().toUri());
records.add(record);
}
}

@ -164,9 +172,9 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace();
return internalServerError();
}
})));
});

server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
return routingDsl.build();
}

@Override
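In the ImageSparkTransformServer hunks above, multipart uploads move from the raw FilePart.getFile() accessor to the typed FilePart<Files.TemporaryFile> with getRef(). A small sketch of reading the first uploaded file under that API, assuming a Play version with play.libs.Files.TemporaryFile exposing path() (the helper class and method names are illustrative, not from the diff):

import play.libs.Files;
import play.mvc.Http;
import java.io.File;
import java.util.List;

public class MultipartSketch {
    // Returns the first uploaded file from a multipart request, or null if none was sent
    static File firstUploadedFile(Http.Request req) {
        Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
        List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
        if (files.isEmpty() || files.get(0).getRef() == null) {
            return null; // nothing uploaded
        }
        // getRef() hands back the TemporaryFile; path() exposes the underlying java.nio.Path
        return files.get(0).getRef().path().toFile();
    }
}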
@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.server.Server;

import static play.mvc.Controller.request;

@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService {
server.stop();
}

protected boolean isSequence() {
return request().hasHeader(SEQUENCE_OR_NOT_HEADER)
&& request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase()
.equals("TRUE");
protected boolean isSequence(Http.Request request) {
return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
&& request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
}

protected String getHeaderValue(String value) {
if (request().hasHeader(value))
return request().getHeader(value);
return null;
}

protected String getJsonText() {
JsonNode tryJson = request().body().asJson();
protected String getJsonText(Http.Request request) {
JsonNode tryJson = request.body().asJson();
if (tryJson != null)
return tryJson.toString();
else
return request().body().asText();
return request.body().asText();
}

public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
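The SparkTransformServer hunk above threads the Http.Request through as an explicit parameter instead of reading the static Controller.request(), which newer Play releases drop. A sketch of the same header check written against the explicit request, using Optional handling rather than the diff's hasHeader/header().get() pair (the header name value below is a placeholder, not from the diff):

import play.mvc.Http;

public class RequestHelperSketch {
    static final String SEQUENCE_OR_NOT_HEADER = "Sequence"; // placeholder value

    static boolean isSequence(Http.Request request) {
        // header(...) returns Optional<String>, so a missing header is handled without get()
        return request.header(SEQUENCE_OR_NOT_HEADER)
                .map(v -> v.equalsIgnoreCase("true"))
                .orElse(false);
    }
}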
@ -0,0 +1,350 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other subsitution, then
# HOCON will fall back to substituting environment variable:
#mykey = ${JAVA_HOME}

## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
# "akka.log-config-on-start" is extraordinarly useful because it log the complete
# configuration at INFO level, including defaults and overrides, so it s worth
# putting at the very top.
#
# Put the following in your conf/logback.xml file:
#
# <logger name="akka.actor" level="INFO" />
#
# And then uncomment this line to debug the configuration.
#
#log-config-on-start = true
}

## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publically available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
# By default, Play will load any class called Module that is defined
# in the root package (the "app" directory), or you can define them
# explicitly below.
# If there are any built-in modules that you want to disable, you can list them here.
#enabled += my.application.Module

# If there are any built-in modules that you want to disable, you can list them here.
#disabled += ""
}

## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
# The application languages
langs = [ "en" ]

# Whether the language cookie should be secure or not
#langCookieSecure = true

# Whether the HTTP only attribute of the cookie should be set to true
#langCookieHttpOnly = true
}

## Play HTTP settings
# ~~~~~
play.http {
## Router
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# Define the Router object to use for this application.
# This router will be looked up first when the application is starting up,
# so make sure this is the entry point.
# Furthermore, it's assumed your route file is named properly.
# So for an application router like `my.application.Router`,
# you may need to define a router file `conf/my.application.routes`.
# Default to Routes in the root package (aka "apps" folder) (and conf/routes)
#router = my.application.Router

## Action Creator
# https://www.playframework.com/documentation/latest/JavaActionCreator
# ~~~~~
#actionCreator = null

## ErrorHandler
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# If null, will attempt to load a class called ErrorHandler in the root package,
#errorHandler = null

## Filters
# https://www.playframework.com/documentation/latest/ScalaHttpFilters
# https://www.playframework.com/documentation/latest/JavaHttpFilters
# ~~~~~
# Filters run code on every request. They can be used to perform
# common logic for all your actions, e.g. adding common headers.
# Defaults to "Filters" in the root package (aka "apps" folder)
# Alternatively you can explicitly register a class here.
#filters += my.application.Filters

## Session & Flash
# https://www.playframework.com/documentation/latest/JavaSessionFlash
# https://www.playframework.com/documentation/latest/ScalaSessionFlash
# ~~~~~
session {
# Sets the cookie to be sent only over HTTPS.
#secure = true

# Sets the cookie to be accessed only by the server.
#httpOnly = true

# Sets the max-age field of the cookie to 5 minutes.
# NOTE: this only sets when the browser will discard the cookie. Play will consider any
# cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
# you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
#maxAge = 300

# Sets the domain on the session cookie.
#domain = "example.com"
}

flash {
# Sets the cookie to be sent only over HTTPS.
#secure = true

# Sets the cookie to be accessed only by the server.
#httpOnly = true
}
}

## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
# Whether the Netty wire should be logged
#log.wire = true

# If you run Play on Linux, you can use Netty's native socket transport
# for higher performance with less garbage.
#transport = "native"
}

## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
# Sets HTTP requests not to follow 302 requests
#followRedirects = false

# Sets the maximum number of open HTTP connections for the client.
#ahc.maxConnectionsTotal = 50

## WS SSL
# https://www.playframework.com/documentation/latest/WsSSL
# ~~~~~
ssl {
# Configuring HTTPS with Play WS does not require programming. You can
# set up both trustManager and keyManager for mutual authentication, and
# turn on JSSE debugging in development with a reload.
#debug.handshake = true
#trustManager = {
# stores = [
# { type = "JKS", path = "exampletrust.jks" }
# ]
#}
}
}

## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
# If you want to bind several caches, you can bind the individually
#bindCaches = ["db-cache", "user-cache", "session-cache"]
}

## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
## CORS filter configuration
# https://www.playframework.com/documentation/latest/CorsFilter
# ~~~~~
# CORS is a protocol that allows web applications to make requests from the browser
# across different domains.
# NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
# dependencies on CORS settings.
cors {
# Filter paths by a whitelist of path prefixes
#pathPrefixes = ["/some/path", ...]

# The allowed origins. If null, all origins are allowed.
#allowedOrigins = ["http://www.example.com"]

# The allowed HTTP methods. If null, all methods are allowed
#allowedHttpMethods = ["GET", "POST"]
}

## CSRF Filter
# https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
# https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
# ~~~~~
# Play supports multiple methods for verifying that a request is not a CSRF request.
# The primary mechanism is a CSRF token. This token gets placed either in the query string
# or body of every form submitted, and also gets placed in the users session.
# Play then verifies that both tokens are present and match.
csrf {
# Sets the cookie to be sent only over HTTPS
#cookie.secure = true

# Defaults to CSRFErrorHandler in the root package.
#errorHandler = MyCSRFErrorHandler
}

## Security headers filter configuration
# https://www.playframework.com/documentation/latest/SecurityHeaders
# ~~~~~
# Defines security headers that prevent XSS attacks.
# If enabled, then all options are set to the below configuration by default:
headers {
# The X-Frame-Options header. If null, the header is not set.
#frameOptions = "DENY"

# The X-XSS-Protection header. If null, the header is not set.
#xssProtection = "1; mode=block"

# The X-Content-Type-Options header. If null, the header is not set.
#contentTypeOptions = "nosniff"

# The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
#permittedCrossDomainPolicies = "master-only"

# The Content-Security-Policy header. If null, the header is not set.
#contentSecurityPolicy = "default-src 'self'"
}

## Allowed hosts filter configuration
# https://www.playframework.com/documentation/latest/AllowedHostsFilter
# ~~~~~
# Play provides a filter that lets you configure which hosts can access your application.
# This is useful to prevent cache poisoning attacks.
hosts {
# Allow requests to example.com, its subdomains, and localhost:9000.
#allowed = [".example.com", "localhost:9000"]
}
}

## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
# You can disable evolutions for a specific datasource if necessary
#db.default.enabled = false
}

## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
# The combination of these two settings results in "db.default" as the
# default JDBC pool:
#config = "db"
#default = "default"

# Play uses HikariCP as the default connection pool. You can override
# settings by changing the prototype:
prototype {
# Sets a fixed JDBC connection pool size of 50
#hikaricp.minimumIdle = 50
#hikaricp.maximumPoolSize = 50
}
}

## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
# You can declare as many datasources as you want.
# By convention, the default datasource is named `default`

# https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
default.driver = org.h2.Driver
default.url = "jdbc:h2:mem:play"
#default.username = sa
#default.password = ""

# You can expose this datasource via JNDI if needed (Useful for JPA)
default.jndiName=DefaultDS

# You can turn on SQL logging for any datasource
# https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
#default.logSql=true
}

jpa.default=defaultPersistenceUnit


#Increase default maximum post length - used for remote listener functionality
#Can get response 413 with larger networks without setting this
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M
@ -28,61 +28,11 @@
<artifactId>datavec-spark_2.11</artifactId>

<properties>
<!-- These Spark versions have to be here, and NOT in the datavec parent pom. Otherwise, the properties
will be resolved to their defaults. Whereas when defined here (the version outputStream spark 1 vs. 2 specific) we can
have different version properties simultaneously (in different pom files) instead of a single global property -->
<spark.version>2.1.0</spark.version>
<spark.major.version>2</spark.major.version>

<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>

<build>
<plugins>
<!-- added source folder containing the code specific to the spark version -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/spark-${spark.major.version}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.11</artifactId>
<version>${jackson.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>

@ -95,42 +45,13 @@
<version>${scala.version}</version>
</dependency>

<dependency>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<version>${jackson-asl.version}</version>
</dependency>
<dependency>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
<version>${jackson-asl.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>${spark.version}</version>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency>
<groupId>com.google.inject</groupId>
<artifactId>guice</artifactId>
<version>${guice.version}</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${google.protobuf.version}</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>${commons-codec.version}</version>
</dependency>
<dependency>
<groupId>commons-collections</groupId>
<artifactId>commons-collections</artifactId>

@ -141,96 +62,16 @@
<artifactId>commons-io</artifactId>
<version>${commons-io.version}</version>
</dependency>
<dependency>
<groupId>commons-lang</groupId>
<artifactId>commons-lang</artifactId>
<version>${commons-lang.version}</version>
</dependency>
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>${commons-net.version}</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-core</artifactId>
<version>${jaxb.version}</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-impl</artifactId>
<version>${jaxb.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-remote_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-slf4j_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>${servlet.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>${commons-compress.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>${commons-math3.version}</version>
</dependency>
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
<version>${curator.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe</groupId>
<artifactId>config</artifactId>
<version>${typesafe.config.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>

@ -241,14 +82,6 @@
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
</exclusions>
</dependency>

@ -281,13 +114,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;

@ -46,7 +43,6 @@ import java.util.List;
import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col;
import static org.datavec.spark.transform.DataRowsFacade.dataRows;

/**

@ -71,7 +67,7 @@ public class DataFrames {
* deviation for
* @return the column that represents the standard deviation
*/
public static Column std(DataRowsFacade dataFrame, String columnName) {
public static Column std(Dataset<Row> dataFrame, String columnName) {
return functions.sqrt(var(dataFrame, columnName));
}

@ -85,8 +81,8 @@ public class DataFrames {
* deviation for
* @return the column that represents the standard deviation
*/
public static Column var(DataRowsFacade dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
public static Column var(Dataset<Row> dataFrame, String columnName) {
return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
}

/**

@ -97,8 +93,8 @@ public class DataFrames {
* @param columnName the name of the column to get the min for
* @return the column that represents the min
*/
public static Column min(DataRowsFacade dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName);
public static Column min(Dataset<Row> dataFrame, String columnName) {
return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName);
}

/**

@ -110,8 +106,8 @@ public class DataFrames {
* to get the max for
* @return the column that represents the max
*/
public static Column max(DataRowsFacade dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName);
public static Column max(Dataset<Row> dataFrame, String columnName) {
return dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName);
}

/**

@ -122,8 +118,8 @@ public class DataFrames {
* @param columnName the name of the column to get the mean for
* @return the column that represents the mean
*/
public static Column mean(DataRowsFacade dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName);
public static Column mean(Dataset<Row> dataFrame, String columnName) {
return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName);
}

/**

@ -166,7 +162,7 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.<br>
* These two columns are required if the data is to be converted back into a sequence at a later point, for example
* using {@link #toRecordsSequence(DataRowsFacade)}
* using {@link #toRecordsSequence(Dataset<Row>)}
*
* @param schema Schema to convert
* @return StructType for the schema

@ -250,9 +246,9 @@ public class DataFrames {
* @param dataFrame the dataframe to convert
* @return the converted schema and rdd of writables
*/
public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(DataRowsFacade dataFrame) {
Schema schema = fromStructType(dataFrame.get().schema());
return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema)));
public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(Dataset<Row> dataFrame) {
Schema schema = fromStructType(dataFrame.schema());
return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema)));
}

/**

@ -267,11 +263,11 @@ public class DataFrames {
* @param dataFrame Data frame to convert
* @return Data in sequence (i.e., {@code List<List<Writable>>} form
*/
public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(DataRowsFacade dataFrame) {
public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(Dataset<Row> dataFrame) {

//Need to convert from flattened to sequence data...
//First: Group by the Sequence UUID (first column)
JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.get().javaRDD().groupBy(new Function<Row, String>() {
JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.javaRDD().groupBy(new Function<Row, String>() {
@Override
public String call(Row row) throws Exception {
return row.getString(0);

@ -279,7 +275,7 @@ public class DataFrames {
});

Schema schema = fromStructType(dataFrame.get().schema());
Schema schema = fromStructType(dataFrame.schema());

//Group by sequence UUID, and sort each row within the sequences using the time step index
Function<Iterable<Row>, List<List<Writable>>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner

@ -318,11 +314,11 @@ public class DataFrames {
* @param data the data to convert
* @return the dataframe object
*/
public static DataRowsFacade toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
public static Dataset<Row> toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
JavaSparkContext sc = new JavaSparkContext(data.context());
SQLContext sqlContext = new SQLContext(sc);
JavaRDD<Row> rows = data.map(new ToRow(schema));
return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema)));
return sqlContext.createDataFrame(rows, fromSchema(schema));
}

@ -333,18 +329,18 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.<br>
* These two columns are required if the data is to be converted back into a sequence at a later point, for example
* using {@link #toRecordsSequence(DataRowsFacade)}
* using {@link #toRecordsSequence(Dataset<Row>)}
*
* @param schema Schema for the data
* @param data Sequence data to convert to a DataFrame
* @return The dataframe object
*/
public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
public static Dataset<Row> toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
JavaSparkContext sc = new JavaSparkContext(data.context());

SQLContext sqlContext = new SQLContext(sc);
JavaRDD<Row> rows = data.flatMap(new SequenceToRows(schema));
return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema)));
return sqlContext.createDataFrame(rows, fromSchemaSequence(schema));
}

/**
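The DataFrames hunks above remove the DataRowsFacade wrapper and hand back Spark's Dataset<Row> directly, so callers no longer unwrap with get(). A hedged sketch of how calling code would now look (schema and data construction are elided, and nothing below is taken verbatim from the diff):

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;

import java.util.List;

public class DataFramesUsageSketch {
    static void example(Schema schema, JavaRDD<List<Writable>> data) {
        // toDataFrame now returns the Spark Dataset directly, no facade .get() unwrapping
        Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
        // ordinary Spark SQL operations are available on the result
        long rowCount = frame.count();
        // and it converts straight back to DataVec records
        JavaRDD<List<Writable>> records = DataFrames.toRecords(frame).getSecond();
    }
}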
@ -19,14 +19,13 @@ package org.datavec.spark.transform;
import org.apache.commons.collections.map.ListOrderedMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;

import java.util.*;

import static org.datavec.spark.transform.DataRowsFacade.dataRows;

/**
* Simple dataframe based normalization.

@ -46,7 +45,7 @@ public class Normalization {
* @return a zero mean unit variance centered
* rdd
*/
public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) {
public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame) {
return zeromeanUnitVariance(frame, Collections.<String>emptyList());
}

@ -71,7 +70,7 @@ public class Normalization {
* @param max the maximum value
* @return the normalized dataframe per column
*/
public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) {
public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max) {
return normalize(dataFrame, min, max, Collections.<String>emptyList());
}

@ -86,7 +85,7 @@ public class Normalization {
*/
public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min,
double max) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, Collections.<String>emptyList())).getSecond();
}

@ -97,7 +96,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale
* @return the normalized dataframe per column
*/
public static DataRowsFacade normalize(DataRowsFacade dataFrame) {
public static Dataset<Row> normalize(Dataset<Row> dataFrame) {
return normalize(dataFrame, 0, 1, Collections.<String>emptyList());
}

@ -120,8 +119,8 @@ public class Normalization {
* @return a zero mean unit variance centered
* rdd
*/
public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(frame.get().columns());
public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(frame.columns());
columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList);
//first row is std second row is mean, each column in a row is for a particular column

@ -133,7 +132,7 @@ public class Normalization {
if (std == 0.0)
std = 1; //All same value -> (x-x)/1 = 0

frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std)));
frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std));
}

@ -152,7 +151,7 @@ public class Normalization {
*/
public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data,
List<String> skipColumns) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond();
}

@ -178,7 +177,7 @@ public class Normalization {
*/
public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema,
JavaRDD<List<List<Writable>>> sequence, List<String> excludeColumns) {
DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence);
Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, sequence);
if (excludeColumns == null)
excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN);
else {

@ -196,7 +195,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
public static List<Row> minMaxColumns(DataRowsFacade data, List<String> columns) {
public static List<Row> minMaxColumns(Dataset<Row> data, List<String> columns) {
String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i);

@ -210,7 +209,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
public static List<Row> minMaxColumns(DataRowsFacade data, String... columns) {
public static List<Row> minMaxColumns(Dataset<Row> data, String... columns) {
return aggregate(data, columns, new String[] {"min", "max"});
}

@ -221,7 +220,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
public static List<Row> stdDevMeanColumns(DataRowsFacade data, List<String> columns) {
public static List<Row> stdDevMeanColumns(Dataset<Row> data, List<String> columns) {
String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i);

@ -237,7 +236,7 @@ public class Normalization {
* @param columns the columns to get the
* @return
*/
public static List<Row> stdDevMeanColumns(DataRowsFacade data, String... columns) {
public static List<Row> stdDevMeanColumns(Dataset<Row> data, String... columns) {
return aggregate(data, columns, new String[] {"stddev", "mean"});
}

@ -251,7 +250,7 @@ public class Normalization {
* Each row will be a function with the desired columnar output
* in the order in which the columns were specified.
*/
public static List<Row> aggregate(DataRowsFacade data, String[] columns, String[] functions) {
public static List<Row> aggregate(Dataset<Row> data, String[] columns, String[] functions) {
String[] rest = new String[columns.length - 1];
System.arraycopy(columns, 1, rest, 0, rest.length);
List<Row> rows = new ArrayList<>();

@ -262,8 +261,8 @@ public class Normalization {
}

//compute the aggregation based on the operation
DataRowsFacade aggregated = dataRows(data.get().agg(expressions));
String[] columns2 = aggregated.get().columns();
Dataset<Row> aggregated = data.agg(expressions);
String[] columns2 = aggregated.columns();
//strip out the op name and parentheses from the columns
Map<String, String> opReplace = new TreeMap<>();
for (String s : columns2) {

@ -278,20 +277,20 @@ public class Normalization {

//get rid of the operation name in the column
DataRowsFacade rearranged = null;
Dataset<Row> rearranged = null;
for (Map.Entry<String, String> entries : opReplace.entrySet()) {
//first column
if (rearranged == null) {
rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue()));
rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue());
}
//rearranged is just a copy of aggregated at this point
else
rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue()));
rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue());
}

rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns)));
rearranged = rearranged.select(DataFrames.toColumns(columns));
//op
rows.addAll(rearranged.get().collectAsList());
rows.addAll(rearranged.collectAsList());
}

@ -307,8 +306,8 @@ public class Normalization {
* @param max the maximum value
* @return the normalized dataframe per column
*/
public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(dataFrame.get().columns());
public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(dataFrame.columns());
columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList);
//first row is min second row is max, each column in a row is for a particular column

@ -321,8 +320,8 @@ public class Normalization {
if (maxSubMin == 0)
maxSubMin = 1;

Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol));
Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
dataFrame = dataFrame.withColumn(columnName, newCol);
}

@ -340,7 +339,7 @@ public class Normalization {
*/
public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max,
List<String> skipColumns) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data);
Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond();
}

@ -387,7 +386,7 @@ public class Normalization {
excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN);
excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN);
}
DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data);
Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, data);
return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond();
}

@ -398,7 +397,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale
* @return the normalized dataframe per column
*/
public static DataRowsFacade normalize(DataRowsFacade dataFrame, List<String> skipColumns) {
public static Dataset<Row> normalize(Dataset<Row> dataFrame, List<String> skipColumns) {
return normalize(dataFrame, 0, 1, skipColumns);
}
@ -16,9 +16,10 @@

package org.datavec.spark.transform.analysis;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;

import java.util.Iterator;
import java.util.List;

/**

@ -27,10 +28,11 @@ import java.util.List;
*
* @author Alex Black
*/
public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<Writable>> {
public class SequenceFlatMapFunction implements FlatMapFunction<List<List<Writable>>, List<Writable>> {

public SequenceFlatMapFunction() {
super(new SequenceFlatMapFunctionAdapter());
@Override
public Iterator<List<Writable>> call(List<List<Writable>> collections) throws Exception {
return collections.iterator();
}

}
@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.analysis;

import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

import java.util.List;

/**
* SequenceFlatMapFunction: very simple function used to flatten a sequence
* Typically used only internally for certain analysis operations
*
* @author Alex Black
*/
public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, List<Writable>> {
@Override
public Iterable<List<Writable>> call(List<List<Writable>> collections) throws Exception {
return collections;
}

}
@ -16,11 +16,14 @@

package org.datavec.spark.transform.join;

import org.nd4j.shade.guava.collect.Iterables;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**

@ -28,10 +31,89 @@ import java.util.List;
*
* @author Alex Black
*/
public class ExecuteJoinFromCoGroupFlatMapFunction extends
BaseFlatMapFunctionAdaptee<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {

private final Join join;

public ExecuteJoinFromCoGroupFlatMapFunction(Join join) {
super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join));
this.join = join;
}

@Override
public Iterator<List<Writable>> call(
Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
throws Exception {

Iterable<List<Writable>> leftList = t2._2()._1();
Iterable<List<Writable>> rightList = t2._2()._2();

List<List<Writable>> ret = new ArrayList<>();
Join.JoinType jt = join.getJoinType();
switch (jt) {
case Inner:
//Return records where key columns appear in BOTH
//So if no values from left OR right: no return values
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
break;
case LeftOuter:
//Return all records from left, even if no corresponding right value (NullWritable in that case)
for (List<Writable> jvl : leftList) {
if (Iterables.size(rightList) == 0) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
} else {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case RightOuter:
//Return all records from right, even if no corresponding left value (NullWritable in that case)
for (List<Writable> jvr : rightList) {
if (Iterables.size(leftList) == 0) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
} else {
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case FullOuter:
//Return all records, even if no corresponding left/right value (NullWritable in that case)
if (Iterables.size(leftList) == 0) {
//Only right values
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
}
} else if (Iterables.size(rightList) == 0) {
//Only left values
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
}
} else {
//Records from both left and right
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
}

return ret.iterator();
}
}
@ -1,119 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.join;

import com.google.common.collect.Iterables;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.List;

/**
* Execute a join
*
* @author Alex Black
*/
public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements
FlatMapFunctionAdapter<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {

private final Join join;

public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) {
this.join = join;
}

@Override
public Iterable<List<Writable>> call(
Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
throws Exception {

Iterable<List<Writable>> leftList = t2._2()._1();
Iterable<List<Writable>> rightList = t2._2()._2();

List<List<Writable>> ret = new ArrayList<>();
Join.JoinType jt = join.getJoinType();
switch (jt) {
case Inner:
//Return records where key columns appear in BOTH
//So if no values from left OR right: no return values
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
break;
case LeftOuter:
//Return all records from left, even if no corresponding right value (NullWritable in that case)
for (List<Writable> jvl : leftList) {
if (Iterables.size(rightList) == 0) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
} else {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case RightOuter:
//Return all records from right, even if no corresponding left value (NullWritable in that case)
for (List<Writable> jvr : rightList) {
if (Iterables.size(leftList) == 0) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
} else {
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case FullOuter:
//Return all records, even if no corresponding left/right value (NullWritable in that case)
if (Iterables.size(leftList) == 0) {
//Only right values
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
}
} else if (Iterables.size(rightList) == 0) {
//Only left values
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
}
} else {
//Records from both left and right
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
}

return ret;
}
}
@ -16,10 +16,12 @@

package org.datavec.spark.transform.join;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**

@ -29,10 +31,43 @@ import java.util.List;
*
* @author Alex Black
*/
public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee<JoinedValue, List<Writable>> {
public class FilterAndFlattenJoinedValues implements FlatMapFunction<JoinedValue, List<Writable>> {

private final Join.JoinType joinType;

public FilterAndFlattenJoinedValues(Join.JoinType joinType) {
super(new FilterAndFlattenJoinedValuesAdapter(joinType));
this.joinType = joinType;
}

@Override
public Iterator<List<Writable>> call(JoinedValue joinedValue) throws Exception {
boolean keep;
switch (joinType) {
case Inner:
//Only keep joined values where we have both left and right
keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
break;
case LeftOuter:
//Keep all values where left is not missing/null
keep = joinedValue.isHaveLeft();
break;
case RightOuter:
//Keep all values where right is not missing/null
keep = joinedValue.isHaveRight();
break;
case FullOuter:
//Keep all values
keep = true;
break;
default:
throw new RuntimeException("Unknown/not implemented join type: " + joinType);
}

if (keep) {
return Collections.singletonList(joinedValue.getValues()).iterator();
} else {
return Collections.emptyIterator();
}
}

}
@ -1,71 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.join;

import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

import java.util.Collections;
import java.util.List;

/**
* Doing two things here:
* (a) filter out any unnecessary values, and
* (b) extract the List<Writable> values from the JoinedValue
*
* @author Alex Black
*/
public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter<JoinedValue, List<Writable>> {

private final Join.JoinType joinType;

public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) {
this.joinType = joinType;
}

@Override
public Iterable<List<Writable>> call(JoinedValue joinedValue) throws Exception {
boolean keep;
switch (joinType) {
case Inner:
//Only keep joined values where we have both left and right
keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
break;
case LeftOuter:
//Keep all values where left is not missing/null
keep = joinedValue.isHaveLeft();
break;
case RightOuter:
//Keep all values where right is not missing/null
keep = joinedValue.isHaveRight();
break;
case FullOuter:
//Keep all values
keep = true;
break;
default:
throw new RuntimeException("Unknown/not implemented join type: " + joinType);
}

if (keep) {
return Collections.singletonList(joinedValue.getValues());
} else {
return Collections.emptyList();
}
}
}
@ -16,21 +16,69 @@

package org.datavec.spark.transform.sparkfunction;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import org.datavec.spark.transform.DataFrames;

import java.util.List;
import java.util.*;

/**
* Convert a record to a row
* @author Adam Gibson
*/
public class SequenceToRows extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, Row> {
public class SequenceToRows implements FlatMapFunction<List<List<Writable>>, Row> {

private Schema schema;
private StructType structType;

public SequenceToRows(Schema schema) {
super(new SequenceToRowsAdapter(schema));
this.schema = schema;
structType = DataFrames.fromSchemaSequence(schema);
}

@Override
public Iterator<Row> call(List<List<Writable>> sequence) throws Exception {
if (sequence.size() == 0)
return Collections.emptyIterator();

String sequenceUUID = UUID.randomUUID().toString();

List<Row> out = new ArrayList<>(sequence.size());

int stepCount = 0;
for (List<Writable> step : sequence) {
Object[] values = new Object[step.size() + 2];
values[0] = sequenceUUID;
values[1] = stepCount++;
for (int i = 0; i < step.size(); i++) {
switch (schema.getColumnTypes().get(i)) {
case Double:
values[i + 2] = step.get(i).toDouble();
break;
case Integer:
values[i + 2] = step.get(i).toInt();
break;
case Long:
values[i + 2] = step.get(i).toLong();
break;
case Float:
values[i + 2] = step.get(i).toFloat();
break;
default:
throw new IllegalStateException(
"This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
}
}

Row row = new GenericRowWithSchema(values, structType);
out.add(row);
}

return out.iterator();
}
}
@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.sparkfunction;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.DataFrames;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

/**
* Convert a record to a row
* @author Adam Gibson
*/
public class SequenceToRowsAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, Row> {

private Schema schema;
private StructType structType;

public SequenceToRowsAdapter(Schema schema) {
this.schema = schema;
structType = DataFrames.fromSchemaSequence(schema);
}

@Override
public Iterable<Row> call(List<List<Writable>> sequence) throws Exception {
if (sequence.size() == 0)
return Collections.emptyList();

String sequenceUUID = UUID.randomUUID().toString();

List<Row> out = new ArrayList<>(sequence.size());

int stepCount = 0;
for (List<Writable> step : sequence) {
Object[] values = new Object[step.size() + 2];
values[0] = sequenceUUID;
values[1] = stepCount++;
for (int i = 0; i < step.size(); i++) {
switch (schema.getColumnTypes().get(i)) {
case Double:
values[i + 2] = step.get(i).toDouble();
break;
case Integer:
values[i + 2] = step.get(i).toInt();
break;
case Long:
values[i + 2] = step.get(i).toLong();
break;
case Float:
values[i + 2] = step.get(i).toFloat();
break;
default:
throw new IllegalStateException(
"This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
}
}

Row row = new GenericRowWithSchema(values, structType);
out.add(row);
}

return out;
}
}
@ -16,19 +16,27 @@

package org.datavec.spark.transform.transform;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;

import java.util.Iterator;
import java.util.List;

/**
* Created by Alex on 17/03/2016.
*/
public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<List<Writable>>> {
public class SequenceSplitFunction implements FlatMapFunction<List<List<Writable>>, List<List<Writable>>> {

private final SequenceSplit split;

public SequenceSplitFunction(SequenceSplit split) {
super(new SequenceSplitFunctionAdapter(split));
this.split = split;
}

@Override
public Iterator<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
return split.split(collections).iterator();
}

}
@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.transform;

import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

import java.util.List;

/**
* Created by Alex on 17/03/2016.
*/
public class SequenceSplitFunctionAdapter
implements FlatMapFunctionAdapter<List<List<Writable>>, List<List<Writable>>> {

private final SequenceSplit split;

public SequenceSplitFunctionAdapter(SequenceSplit split) {
this.split = split;
}

@Override
public Iterable<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
return split.split(collections);
}
}
@ -16,19 +16,32 @@

package org.datavec.spark.transform.transform;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
* Spark function for executing a transform process
*/
public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee<List<Writable>, List<Writable>> {
public class SparkTransformProcessFunction implements FlatMapFunction<List<Writable>, List<Writable>> {

private final TransformProcess transformProcess;

public SparkTransformProcessFunction(TransformProcess transformProcess) {
super(new SparkTransformProcessFunctionAdapter(transformProcess));
this.transformProcess = transformProcess;
}

@Override
public Iterator<List<Writable>> call(List<Writable> v1) throws Exception {
List<Writable> newList = transformProcess.execute(v1);
if (newList == null)
return Collections.emptyIterator(); //Example was filtered out
else
return Collections.singletonList(newList).iterator();
}

}
@ -1,45 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform.transform;

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

import java.util.Collections;
import java.util.List;

/**
* Spark function for executing a transform process
*/
public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter<List<Writable>, List<Writable>> {

private final TransformProcess transformProcess;

public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) {
this.transformProcess = transformProcess;
}

@Override
public Iterable<List<Writable>> call(List<Writable> v1) throws Exception {
List<Writable> newList = transformProcess.execute(v1);
if (newList == null)
return Collections.emptyList(); //Example was filtered out
else
return Collections.singletonList(newList);
}
}
@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

/**
* FlatMapFunction adapter to
* hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to FlatMapFunction
*
*/
public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {

protected final FlatMapFunctionAdapter<K, V> adapter;

public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
this.adapter = adapter;
}

@Override
public Iterable<V> call(K k) throws Exception {
return adapter.call(k);
}
}
@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform;

import org.apache.spark.sql.DataFrame;

/**
* Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to DataFrame / Dataset
*
*/
public class DataRowsFacade {

private final DataFrame df;

private DataRowsFacade(DataFrame df) {
this.df = df;
}

public static DataRowsFacade dataRows(DataFrame df) {
return new DataRowsFacade(df);
}

public DataFrame get() {
return df;
}
}
@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;

import java.util.Iterator;

/**
* FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to FlatMapFunction
*
*/
public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {

protected final FlatMapFunctionAdapter<K, V> adapter;

public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
this.adapter = adapter;
}

@Override
public Iterator<V> call(K k) throws Exception {
return adapter.call(k).iterator();
}
}
@ -1,43 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.datavec.spark.transform;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
* Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to DataFrame / Dataset
*
*/
public class DataRowsFacade {

private final Dataset<Row> df;

private DataRowsFacade(Dataset<Row> df) {
this.df = df;
}

public static DataRowsFacade dataRows(Dataset<Row> df) {
return new DataRowsFacade(df);
}

public Dataset<Row> get() {
return df;
}
}
@ -16,7 +16,7 @@

package org.datavec.spark.storage;

import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.*;
@ -19,6 +19,8 @@ package org.datavec.spark.transform;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;

@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i));
Schema schema = builder.build();
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
dataFrame.get().show();
dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show();
Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
dataFrame.show();
dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show();
// System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames()));
// System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames()));

@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());

DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.get().show();
Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.show();
Column mean = DataFrames.mean(dataFrame, "0");
Column std = DataFrames.std(dataFrame, "0");
dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show();
dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show();
dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show();
dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show();

/* DataFrame desc = dataFrame.describe(dataFrame.columns());
dataFrame.show();
@ -17,6 +17,7 @@
package org.datavec.spark.transform;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;

@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i));

Nd4j.getRandom().setSeed(12345);

INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns);
for (int i = 0; i < 5; i++) {
List<Writable> record = new ArrayList<>(numColumns);
data.add(record);
for (int j = 0; j < numColumns; j++) {
record.add(new DoubleWritable(1.0));
record.add(new DoubleWritable(arr.getDouble(i, j)));
}

}

INDArray arr = RecordConverter.toMatrix(data);

Schema schema = builder.build();
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);

//assert equivalent to the ndarray pre processing
NormalizerStandardize standardScaler = new NormalizerStandardize();
standardScaler.fit(new DataSet(arr.dup(), arr.dup()));
INDArray standardScalered = arr.dup();
standardScaler.transform(new DataSet(standardScalered, standardScalered));
DataNormalization zeroToOne = new NormalizerMinMaxScaler();
zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
INDArray zeroToOnes = arr.dup();
zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));
List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns());
List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
INDArray assertion = DataFrames.toMatrix(rows);
//compare standard deviation
assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1));
INDArray expStd = arr.std(true, true, 0);
INDArray std = assertion.getRow(0, true);
assertTrue(expStd.equalsWithEps(std, 1e-3));
//compare mean
assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1));
INDArray expMean = arr.mean(true, 0);
assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3));

}

@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());

DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.get().show();
Normalization.zeromeanUnitVariance(dataFrame).get().show();
Normalization.normalize(dataFrame).get().show();
Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.show();
Normalization.zeromeanUnitVariance(dataFrame).show();
Normalization.normalize(dataFrame).show();

//assert equivalent to the ndarray pre processing
NormalizerStandardize standardScaler = new NormalizerStandardize();
@ -330,53 +330,6 @@
</plugins>
</build>

<profiles>
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
Note that this excludes DL4J-cuda -->
<profile>
<id>test-nd4j-native</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${dl4j-test-resources.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
<profile>
<id>test-nd4j-cuda-10.1</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${dl4j-test-resources.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<!-- Default to ALL modules here, unlike nd4j-native -->
</profile>
</profiles>

<reporting>
<plugins>
<plugin>

@ -391,4 +344,42 @@
</plugin>
</plugins>
</reporting>

<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>

<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>
@ -52,18 +52,6 @@ public class DL4JSystemProperties {
*/
public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl";

/**
* Applicability: deeplearning4j-nn<br>
* Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an
* alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in
* comma-separated format.<br>
* This is required ONLY when ALL of the following conditions are met:<br>
* 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND<br>
* 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J), AND<br>
* 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}
*/
public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses";

/**
* Applicability: deeplearning4j-nn<br>
* Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled
@ -29,23 +29,11 @@
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

@ -82,7 +70,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>

@ -109,18 +96,6 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>

@ -132,8 +107,6 @@
<version>${commonslang.version}</version>
</dependency>

<!-- ND4J Shaded Jackson Dependency -->
<dependency>
<groupId>org.nd4j</groupId>

@ -141,7 +114,6 @@
<version>${nd4j.version}</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>

@ -180,9 +152,26 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>

<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>
@ -16,7 +16,7 @@

package org.deeplearning4j.datasets.datavec;

import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;

@ -17,7 +17,7 @@
package org.deeplearning4j.datasets.datavec;

import com.google.common.io.Files;
import org.nd4j.shade.guava.io.Files;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
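
The two hunks above only swap the import from com.google.common.io.Files to org.nd4j.shade.guava.io.Files. As a minimal sketch of why no other change is needed - assuming the shaded artifact repackages the stock Guava API unchanged, which is what this one-line substitution relies on - the same helper methods are simply called on the shaded class:

import java.io.File;
import java.nio.charset.StandardCharsets;

import org.nd4j.shade.guava.io.Files;   // was: com.google.common.io.Files

public class ShadedGuavaFilesExample {
    public static void main(String[] args) throws Exception {
        File dir = Files.createTempDir();                              // same Guava helper, shaded package
        File f = new File(dir, "example.txt");
        Files.write("hello".getBytes(StandardCharsets.UTF_8), f);      // Files.write(byte[], File)
        System.out.println(new String(Files.toByteArray(f), StandardCharsets.UTF_8));
    }
}
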
@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.fail;
import java.util.Map;

import static org.junit.Assert.*;

/**
 * A set of tests to ensure that useful exceptions are thrown on invalid input
@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest {
        //Idea: Using rnnTimeStep with a different number of examples between calls
        //(i.e., not calling reset between time steps)

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
                .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
        for(String layerType : new String[]{"simple", "lstm", "graves"}) {

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            Layer l;
            switch (layerType){
                case "simple":
                    l = new SimpleRnn.Builder().nIn(5).nOut(5).build();
                    break;
                case "lstm":
                    l = new LSTM.Builder().nIn(5).nOut(5).build();
                    break;
                case "graves":
                    l = new GravesLSTM.Builder().nIn(5).nOut(5).build();
                    break;
                default:
                    throw new RuntimeException();
            }

            net.rnnTimeStep(Nd4j.create(3, 5, 10));
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                    .layer(l)
                    .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();

            try {
                net.rnnTimeStep(Nd4j.create(5, 5, 10));
                fail("Expected DL4JException");
            } catch (DL4JException e) {
                System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
            } catch (Exception e) {
                e.printStackTrace();
                fail("Expected DL4JException");
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();

            net.rnnTimeStep(Nd4j.create(3, 5, 10));

            Map<String, INDArray> m = net.rnnGetPreviousState(0);
            assertNotNull(m);
            assertFalse(m.isEmpty());

            try {
                net.rnnTimeStep(Nd4j.create(5, 5, 10));
                fail("Expected Exception - " + layerType);
            } catch (Exception e) {
                // e.printStackTrace();
                String msg = e.getMessage();
                assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
            }
        }
    }
}
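
As the updated test asserts, calling rnnTimeStep again with a different minibatch size and without resetting the stored recurrent state is expected to fail with a message mentioning "rnn" and "batch". A minimal sketch of the legal alternative, assuming net is an initialized MultiLayerNetwork with an RNN layer like those built in the test above: clear the cached state before switching batch sizes.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.factory.Nd4j;

public class RnnTimeStepResetSketch {
    static void switchBatchSize(MultiLayerNetwork net) {
        net.rnnTimeStep(Nd4j.create(3, 5, 10));   // first call: minibatch size 3
        net.rnnClearPreviousState();              // discard stored state before changing the batch size
        net.rnnTimeStep(Nd4j.create(5, 5, 10));   // a minibatch of 5 is now accepted
    }
}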