Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
AlexDBlack 2019-08-31 12:33:22 +10:00
commit b393d3fdb1
531 changed files with 12458 additions and 11429 deletions

View File

@ -72,6 +72,11 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>${jodatime.version}</version>
</dependency>
<!-- ND4J Shaded Jackson Dependency --> <!-- ND4J Shaded Jackson Dependency -->
<dependency> <dependency>
@ -80,4 +85,13 @@
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project> </project>

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.optimize.distribution; package org.deeplearning4j.arbiter.optimize.distribution;
import com.google.common.base.Preconditions; import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter; import lombok.Getter;
import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.exception.NumberIsTooLargeException; import org.apache.commons.math3.exception.NumberIsTooLargeException;

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.optimize.runner; package org.deeplearning4j.arbiter.optimize.runner;
import com.google.common.util.concurrent.ListenableFuture; import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -16,9 +16,9 @@
package org.deeplearning4j.arbiter.optimize.runner; package org.deeplearning4j.arbiter.optimize.runner;
import com.google.common.util.concurrent.ListenableFuture; import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService; import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors; import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.arbiter.optimize.api.*; import org.deeplearning4j.arbiter.optimize.api.*;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;

View File

@ -43,13 +43,15 @@ public class JsonMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
yamlMapper = new ObjectMapper(new YAMLFactory()); yamlMapper = new ObjectMapper(new YAMLFactory());
mapper.registerModule(new JodaModule()); yamlMapper.registerModule(new JodaModule());
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.enable(SerializationFeature.INDENT_OUTPUT); yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
} }
private JsonMapper() {} private JsonMapper() {}

View File

@ -39,6 +39,7 @@ public class YamlMapper {
mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
} }

View File

@ -59,6 +59,7 @@ public class TestJson {
om.enable(SerializationFeature.INDENT_OUTPUT); om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
return om; return om;
} }

View File

@ -38,13 +38,6 @@
<version>${dl4j.version}</version> <version>${dl4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
@ -64,6 +57,20 @@
<artifactId>jackson</artifactId> <artifactId>jackson</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
</dependencies> </dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project> </project>

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.arbiter.layers; package org.deeplearning4j.arbiter.layers;
import com.google.common.base.Preconditions; import org.nd4j.shade.guava.base.Preconditions;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;

View File

@ -49,11 +49,14 @@
<version>${junit.version}</version> <version>${junit.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles>
</project> </project>

View File

@ -97,15 +97,16 @@
</plugins> </plugins>
</build> </build>
</profile> </profile>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
</profile>
</profiles> </profiles>
<dependencies> <dependencies>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-core</artifactId> <artifactId>arbiter-core</artifactId>
@ -124,13 +125,6 @@
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
@ -139,9 +133,6 @@
</dependency> </dependency>
</dependencies> </dependencies>
<build> <build>
<extensions> <extensions>
<extension> <extension>
@ -222,5 +213,4 @@
</plugins> </plugins>
</pluginManagement> </pluginManagement>
</build> </build>
</project> </project>

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.ui.misc;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect; import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.MapperFeature;
@ -45,12 +46,9 @@ public class JsonMapper {
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true); mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.enable(SerializationFeature.INDENT_OUTPUT);
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker() mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
.withFieldVisibility(JsonAutoDetect.Visibility.ANY) mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
.withCreatorVisibility(JsonAutoDetect.Visibility.NONE));
return mapper; return mapper;
} }

View File

@ -136,6 +136,31 @@
<pluginManagement> <pluginManagement>
<plugins> <plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>${maven-enforcer-plugin.version}</version>
<executions>
<execution>
<phase>test</phase>
<id>enforce-test-resources</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<skip>${skipTestResourceEnforcement}</skip>
<rules>
<requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-10.1</profiles>
<all>false</all>
</requireActiveProfile>
</rules>
<fail>true</fail>
</configuration>
</execution>
</executions>
</plugin>
<plugin> <plugin>
<artifactId>maven-javadoc-plugin</artifactId> <artifactId>maven-javadoc-plugin</artifactId>
<version>${maven-javadoc-plugin.version}</version> <version>${maven-javadoc-plugin.version}</version>
@ -287,4 +312,42 @@
</plugin> </plugin>
</plugins> </plugins>
</reporting> </reporting>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project> </project>

View File

@ -20,9 +20,9 @@
set -e set -e
VALID_VERSIONS=( 2.10 2.11 ) VALID_VERSIONS=( 2.11 2.12 )
SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}";
SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}"; SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}";
SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}";
usage() { usage() {
echo "Usage: $(basename $0) [-h|--help] <scala version to be used> echo "Usage: $(basename $0) [-h|--help] <scala version to be used>
@ -45,19 +45,18 @@ check_scala_version() {
exit 1 exit 1
} }
check_scala_version "$TO_VERSION" check_scala_version "$TO_VERSION"
if [ $TO_VERSION = "2.11" ]; then if [ $TO_VERSION = "2.11" ]; then
FROM_BINARY="_2\.10" FROM_BINARY="_2\.12"
TO_BINARY="_2\.11" TO_BINARY="_2\.11"
FROM_VERSION=$SCALA_210_VERSION FROM_VERSION=$SCALA_212_VERSION
TO_VERSION=$SCALA_211_VERSION TO_VERSION=$SCALA_211_VERSION
else else
FROM_BINARY="_2\.11" FROM_BINARY="_2\.11"
TO_BINARY="_2\.10" TO_BINARY="_2\.12"
FROM_VERSION=$SCALA_211_VERSION FROM_VERSION=$SCALA_211_VERSION
TO_VERSION=$SCALA_210_VERSION TO_VERSION=$SCALA_212_VERSION
fi fi
sed_i() { sed_i() {
@ -70,35 +69,24 @@ echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION t
BASEDIR=$(dirname $0) BASEDIR=$(dirname $0)
#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc. #Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc.
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \; -exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \;
#Scala versions, like <scala.version>2.10</scala.version> #Scala versions, like <scala.version>2.11</scala.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \; -exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \;
#Scala binary versions, like <scala.binary.version>2.10</scala.binary.version> #Scala binary versions, like <scala.binary.version>2.11</scala.binary.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \; -exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \;
#Scala versions, like <artifactId>scala-library</artifactId> <version>2.10.6</version> #Scala versions, like <artifactId>scala-library</artifactId> <version>2.11.12</version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \; -exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \;
#Scala maven plugin, <scalaVersion>2.10</scalaVersion> #Scala maven plugin, <scalaVersion>2.11</scalaVersion>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \; -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \;
#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306
#https://github.com/deeplearning4j/deeplearning4j/issues/6306
if [[ $TO_VERSION == 2.11* ]]; then
sed_i 's/<artifactId>korean-text-scala-2.10<\/artifactId>/<artifactId>korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
sed_i 's/<version>4.2.0<\/version>/<version>4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
else
sed_i 's/<artifactId>korean-text<\/artifactId>/<artifactId>korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
sed_i 's/<version>4.4<\/version>/<version>4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
fi
echo "Done updating Scala versions."; echo "Done updating Scala versions.";

View File

@ -1,85 +0,0 @@
#!/usr/bin/env bash
################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications.
set -e
VALID_VERSIONS=( 1 2 )
SPARK_2_VERSION="2\.1\.0"
SPARK_1_VERSION="1\.6\.3"
usage() {
echo "Usage: $(basename $0) [-h|--help] <spark version to be used>
where :
-h| --help Display this help text
valid spark version values : ${VALID_VERSIONS[*]}
" 1>&2
exit 1
}
if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
usage
fi
TO_VERSION=$1
check_spark_version() {
for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
exit 1
}
check_spark_version "$TO_VERSION"
if [ $TO_VERSION = "2" ]; then
FROM_BINARY="1"
TO_BINARY="2"
FROM_VERSION=$SPARK_1_VERSION
TO_VERSION=$SPARK_2_VERSION
else
FROM_BINARY="2"
TO_BINARY="1"
FROM_VERSION=$SPARK_2_VERSION
TO_VERSION=$SPARK_1_VERSION
fi
sed_i() {
sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
}
export -f sed_i
echo "Updating Spark versions in pom.xml files to Spark $1";
BASEDIR=$(dirname $0)
# <spark.major.version>1</spark.major.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \;
# <spark.version>1.6.3</spark.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \;
#Spark versions, like <version>xxx_spark_2xxx</version> OR <datavec.spark.version>xxx_spark_2xxx</datavec.spark.version>
find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \
-exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \;
echo "Done updating Spark versions.";

View File

@ -26,7 +26,6 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
@ -98,13 +97,6 @@
<version>${stream.analytics.version}</version> <version>${stream.analytics.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- csv parser, same dep used by spark --> <!-- csv parser, same dep used by spark -->
<dependency> <dependency>
<groupId>net.sf.opencsv</groupId> <groupId>net.sf.opencsv</groupId>
@ -125,7 +117,6 @@
</dependency> </dependency>
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>

View File

@ -16,7 +16,6 @@
package org.datavec.api.transform; package org.datavec.api.transform;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -27,8 +26,7 @@ import java.util.List;
/**A Transform converts an example to another example, or a sequence to another sequence /**A Transform converts an example to another example, or a sequence to another sequence
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.TransformHelper.class)
public interface Transform extends Serializable, ColumnOp { public interface Transform extends Serializable, ColumnOp {
/** /**

View File

@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
@ -417,6 +418,16 @@ public class TransformProcess implements Serializable {
public static TransformProcess fromJson(String json) { public static TransformProcess fromJson(String json) {
try { try {
return JsonMappers.getMapper().readValue(json, TransformProcess.class); return JsonMappers.getMapper().readValue(json, TransformProcess.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
try{
return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e) { } catch (IOException e) {
//TODO proper exception message //TODO proper exception message
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer; import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.node.ArrayNode; import org.nd4j.shade.jackson.databind.node.ArrayNode;
import java.io.IOException; import java.io.IOException;
@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable {
public static DataAnalysis fromJson(String json) { public static DataAnalysis fromJson(String json) {
try{ try{
return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class); return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (Exception e){ } catch (Exception e){
//Legacy format //Legacy format
ObjectMapper om = new JsonSerializer().getObjectMapper(); ObjectMapper om = new JsonSerializer().getObjectMapper();

View File

@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis; import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis; import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer; import org.datavec.api.transform.serde.YamlSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis {
public static SequenceDataAnalysis fromJson(String json){ public static SequenceDataAnalysis fromJson(String json){
try{ try{
return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class); return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (IOException e){ } catch (IOException e){
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.analysis.columns; package org.datavec.api.transform.analysis.columns;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -27,8 +26,7 @@ import java.io.Serializable;
* Interface for column analysis * Interface for column analysis
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class)
public interface ColumnAnalysis extends Serializable { public interface ColumnAnalysis extends Serializable {
long getCountTotal(); long getCountTotal();

View File

@ -18,7 +18,6 @@ package org.datavec.api.transform.condition;
import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -35,8 +34,7 @@ import java.util.List;
* @author Alex Black * @author Alex Black
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.ConditionHelper.class)
public interface Condition extends Serializable, ColumnOp { public interface Condition extends Serializable, ColumnOp {
/** /**

View File

@ -18,7 +18,6 @@ package org.datavec.api.transform.filter;
import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -33,8 +32,7 @@ import java.util.List;
* @author Alex Black * @author Alex Black
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.FilterHelper.class)
public interface Filter extends Serializable, ColumnOp { public interface Filter extends Serializable, ColumnOp {
/** /**

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.metadata; package org.datavec.api.transform.metadata;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -32,8 +31,7 @@ import java.io.Serializable;
* @author Alex Black * @author Alex Black
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class)
public interface ColumnMetaData extends Serializable, Cloneable { public interface ColumnMetaData extends Serializable, Cloneable {
/** /**

View File

@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable;
import java.util.List; import java.util.List;
import static com.google.common.base.Preconditions.checkArgument; import static org.nd4j.shade.guava.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull; import static org.nd4j.shade.guava.base.Preconditions.checkNotNull;
/** /**
* A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition}, * A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition},

View File

@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.metadata.LongMetaData;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.comparator.WritableComparator; import org.datavec.api.writable.comparator.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
@ -50,8 +49,7 @@ import java.util.List;
@EqualsAndHashCode(exclude = {"inputSchema"}) @EqualsAndHashCode(exclude = {"inputSchema"})
@JsonIgnoreProperties({"inputSchema"}) @JsonIgnoreProperties({"inputSchema"})
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class)
public class CalculateSortedRank implements Serializable, ColumnOp { public class CalculateSortedRank implements Serializable, ColumnOp {
private final String newColumnName; private final String newColumnName;

View File

@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.serde.JsonMappers; import org.datavec.api.transform.serde.JsonMappers;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.nd4j.shade.jackson.annotation.*; import org.nd4j.shade.jackson.annotation.*;
@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule; import org.nd4j.shade.jackson.datatype.joda.JodaModule;
import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.util.*; import java.util.*;
@ -48,8 +49,7 @@ import java.util.*;
*/ */
@JsonIgnoreProperties({"columnNames", "columnNamesIndex"}) @JsonIgnoreProperties({"columnNames", "columnNamesIndex"})
@EqualsAndHashCode @EqualsAndHashCode
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.SchemaHelper.class)
@Data @Data
public class Schema implements Serializable { public class Schema implements Serializable {
@ -358,6 +358,16 @@ public class Schema implements Serializable {
public static Schema fromJson(String json) { public static Schema fromJson(String json) {
try{ try{
return JsonMappers.getMapper().readValue(json, Schema.class); return JsonMappers.getMapper().readValue(json, Schema.class);
} catch (InvalidTypeIdException e){
if(e.getMessage().contains("@class")){
try{
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
return JsonMappers.getLegacyMapper().readValue(json, Schema.class);
} catch (IOException e2){
throw new RuntimeException(e2);
}
}
throw new RuntimeException(e);
} catch (Exception e){ } catch (Exception e){
//TODO better exceptions //TODO better exceptions
throw new RuntimeException(e); throw new RuntimeException(e);
@ -379,21 +389,6 @@ public class Schema implements Serializable {
} }
} }
private static Schema fromJacksonString(String str, JsonFactory factory) {
ObjectMapper om = new ObjectMapper(factory);
om.registerModule(new JodaModule());
om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
om.enable(SerializationFeature.INDENT_OUTPUT);
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
try {
return om.readValue(str, Schema.class);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static class Builder { public static class Builder {
List<ColumnMetaData> columnMetaData = new ArrayList<>(); List<ColumnMetaData> columnMetaData = new ArrayList<>();

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence; package org.datavec.api.transform.sequence;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -30,8 +29,7 @@ import java.util.List;
* Compare the time steps of a sequence * Compare the time steps of a sequence
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class)
public interface SequenceComparator extends Comparator<List<Writable>>, Serializable { public interface SequenceComparator extends Comparator<List<Writable>>, Serializable {
/** /**

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence; package org.datavec.api.transform.sequence;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -32,8 +31,7 @@ import java.util.List;
* @author Alex Black * @author Alex Black
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class)
public interface SequenceSplit extends Serializable { public interface SequenceSplit extends Serializable {
/** /**

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.sequence.window; package org.datavec.api.transform.sequence.window;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -36,8 +35,7 @@ import java.util.List;
* @author Alex Black * @author Alex Black
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class)
public interface WindowFunction extends Serializable { public interface WindowFunction extends Serializable {
/** /**

View File

@ -16,44 +16,17 @@
package org.datavec.api.transform.serde; package org.datavec.api.transform.serde;
import lombok.AllArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.WritableComparator; import org.datavec.api.transform.serde.legacy.LegacyJsonFormat;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.condition.column.ColumnCondition;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.serde.json.LegacyIActivationDeserializer;
import org.nd4j.serde.json.LegacyILossFunctionDeserializer;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect; import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.annotation.PropertyAccessor; import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.databind.*; import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.cfg.MapperConfig; import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.introspect.Annotated; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass; import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.introspect.AnnotationMap;
import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector;
import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.nd4j.shade.jackson.datatype.joda.JodaModule; import org.nd4j.shade.jackson.datatype.joda.JodaModule;
import java.lang.annotation.Annotation;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/** /**
* JSON mappers for deserializing neural net configurations, etc. * JSON mappers for deserializing neural net configurations, etc.
* *
@ -62,38 +35,9 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j @Slf4j
public class JsonMappers { public class JsonMappers {
/**
* This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])}
* Classes can be specified in comma-separated format
*/
public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses";
static {
String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY);
if(p != null && !p.isEmpty()){
String[] split = p.split(",");
List<Class<?>> list = new ArrayList<>();
for(String s : split){
try{
Class<?> c = Class.forName(s);
list.add(c);
} catch (Throwable t){
log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t);
}
}
if(list.size() > 0){
try {
registerLegacyCustomClassesForJSONList(list);
} catch (Throwable t){
log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t);
}
}
}
}
private static ObjectMapper jsonMapper; private static ObjectMapper jsonMapper;
private static ObjectMapper yamlMapper; private static ObjectMapper yamlMapper;
private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc
static { static {
jsonMapper = new ObjectMapper(); jsonMapper = new ObjectMapper();
@ -102,117 +46,12 @@ public class JsonMappers {
configureMapper(yamlMapper); configureMapper(yamlMapper);
} }
private static Map<Class, ObjectMapper> legacyMappers = new ConcurrentHashMap<>(); public static synchronized ObjectMapper getLegacyMapper(){
if(legacyMapper == null){
legacyMapper = LegacyJsonFormat.legacyMapper();
/** configureMapper(legacyMapper);
* Register a set of classes (Transform, Filter, etc) for JSON deserialization.<br>
* <br>
* This is required ONLY when BOTH of the following conditions are met:<br>
* 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND<br>
* 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)<br>
* <br>
* By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON
* format change between versions.
*
* @param classes Classes to register
*/
public static void registerLegacyCustomClassesForJSON(Class<?>... classes) {
registerLegacyCustomClassesForJSONList(Arrays.<Class<?>>asList(classes));
}
/**
* @see #registerLegacyCustomClassesForJSON(Class[])
*/
public static void registerLegacyCustomClassesForJSONList(List<Class<?>> classes){
//Default names (i.e., old format for custom JSON format)
List<Pair<String,Class>> list = new ArrayList<>();
for(Class<?> c : classes){
list.add(new Pair<String,Class>(c.getSimpleName(), c));
} }
registerLegacyCustomClassesForJSON(list); return legacyMapper;
}
/**
* Set of classes that can be registered for legacy deserialization.
*/
private static List<Class<?>> REGISTERABLE_CUSTOM_CLASSES = (List<Class<?>>) Arrays.<Class<?>>asList(
Transform.class,
ColumnAnalysis.class,
ColumnCondition.class,
Filter.class,
ColumnMetaData.class,
CalculateSortedRank.class,
Schema.class,
SequenceComparator.class,
SequenceSplit.class,
WindowFunction.class,
Writable.class,
WritableComparator.class
);
/**
* Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
* ONLY) for JSON deserialization, with custom names.<br>
* Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}
* but is added in case it is required in non-standard circumstances.
*/
public static void registerLegacyCustomClassesForJSON(List<Pair<String,Class>> classes){
for(Pair<String,Class> p : classes){
String s = p.getFirst();
Class c = p.getRight();
//Check if it's a valid class to register...
boolean found = false;
for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){
if(c2.isAssignableFrom(c)){
Map<String,String> map = LegacyMappingHelper.legacyMappingForClass(c2);
map.put(p.getFirst(), p.getSecond().getName());
found = true;
}
}
if(!found){
throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
}
}
}
/**
* Get the legacy JSON mapper for the specified class.<br>
*
* <b>NOTE</b>: This is intended for internal backward-compatibility use.
*
* Note to developers: The following JSON mappers are for handling legacy format JSON.
* Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc were changed from
* a wrapper object, to an "@class" field. However, to not break all saved transforms networks, these mappers are
* part of the solution.<br>
* <br>
* How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)<br>
* 1. Transforms etc JSON that has a "@class" field are deserialized as normal<br>
* 2. Transforms JSON that don't have such a field are mapped (via Layer @JsonTypeInfo) to LegacyMappingHelper.TransformHelper<br>
* 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it<br>
* 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names
* 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization
*
* Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format,
* as it'll fail due to not having the expected "@class" annotation.
* Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified
* class anyway. The ignoring is done via an annotation introspector, defined below in this class.
* However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of
* all types - if we did, then any nested types would fail (i.e., an Condition in a Transform - the Transform couldn't
* be deserialized correctly, as the annotation would be ignored).
*
*/
public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class<?> clazz){
if(!legacyMappers.containsKey(clazz)){
ObjectMapper m = new ObjectMapper();
configureMapper(m);
m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.<Class>singletonList(clazz)));
legacyMappers.put(clazz, m);
}
return legacyMappers.get(clazz);
} }
/** /**
@ -237,61 +76,7 @@ public class JsonMappers {
ret.enable(SerializationFeature.INDENT_OUTPUT); ret.enable(SerializationFeature.INDENT_OUTPUT);
ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); //Need this otherwise JsonProperty annotations on constructors won't be seen
} }
/**
* Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc.
* This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring
* a set of JsonTypeInfo annotations
*/
@AllArgsConstructor
private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector {
private List<Class> classList;
@Override
protected TypeResolverBuilder<?> _findTypeResolver(MapperConfig<?> config, Annotated ann, JavaType baseType) {
if(ann instanceof AnnotatedClass){
AnnotatedClass c = (AnnotatedClass)ann;
Class<?> annClass = c.getAnnotated();
boolean isAssignable = false;
for(Class c2 : classList){
if(c2.isAssignableFrom(annClass)){
isAssignable = true;
break;
}
}
if( isAssignable ){
AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations();
if(annotations == null || annotations.annotations() == null){
//Probably not necessary - but here for safety
return super._findTypeResolver(config, ann, baseType);
}
AnnotationMap newMap = null;
for(Annotation a : annotations.annotations()){
Class<?> annType = a.annotationType();
if(annType == JsonTypeInfo.class){
//Ignore the JsonTypeInfo annotation on the Layer class
continue;
}
if(newMap == null){
newMap = new AnnotationMap();
}
newMap.add(a);
}
if(newMap == null)
return null;
//Pass the remaining annotations (if any) to the original introspector
AnnotatedClass ann2 = c.withAnnotations(newMap);
return super._findTypeResolver(config, ann2, baseType);
}
}
return super._findTypeResolver(config, ann, baseType);
}
}
} }

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform.serde.legacy;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import org.datavec.api.transform.serde.JsonMappers;
import org.nd4j.serde.json.BaseLegacyDeserializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.util.Map;
@AllArgsConstructor
@Data
public class GenericLegacyDeserializer<T> extends BaseLegacyDeserializer<T> {
@Getter
protected final Class<T> deserializedType;
@Getter
protected final Map<String,String> legacyNamesMap;
@Override
public ObjectMapper getLegacyJsonMapper() {
return JsonMappers.getLegacyMapperFor(getDeserializedType());
}
}

View File

@ -0,0 +1,267 @@
package org.datavec.api.transform.serde.legacy;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.*;
import org.datavec.api.transform.condition.BooleanCondition;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.condition.column.*;
import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.filter.InvalidNumColumns;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ReduceSequenceTransform;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
import org.datavec.api.transform.sequence.comparator.StringComparator;
import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.TimeWindowFunction;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.stringreduce.IStringReducer;
import org.datavec.api.transform.stringreduce.StringReducer;
import org.datavec.api.transform.transform.categorical.*;
import org.datavec.api.transform.transform.column.*;
import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
import org.datavec.api.transform.transform.doubletransform.*;
import org.datavec.api.transform.transform.integer.*;
import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
import org.datavec.api.transform.transform.string.*;
import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
import org.datavec.api.transform.transform.time.StringToTimeTransform;
import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*;
import org.datavec.api.writable.comparator.*;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.databind.ObjectMapper;
/**
* This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override
* the existing annotations.
* In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field".
* We use these mixins to allow us to still load the old format
*
* @author Alex Black
*/
public class LegacyJsonFormat {
private LegacyJsonFormat(){ }
/**
* Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before
* @return Object mapper
*/
public static ObjectMapper legacyMapper(){
ObjectMapper om = new ObjectMapper();
om.addMixIn(Schema.class, SchemaMixin.class);
om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class);
om.addMixIn(Transform.class, TransformMixin.class);
om.addMixIn(Condition.class, ConditionMixin.class);
om.addMixIn(Writable.class, WritableMixin.class);
om.addMixIn(Filter.class, FilterMixin.class);
om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class);
om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class);
om.addMixIn(WindowFunction.class, WindowFunctionMixin.class);
om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class);
om.addMixIn(WritableComparator.class, WritableComparatorMixin.class);
om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class);
om.addMixIn(IStringReducer.class, IStringReducerMixin.class);
return om;
}
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"),
@JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class SchemaMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"),
@JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"),
@JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"),
@JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"),
@JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"),
@JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"),
@JsonSubTypes.Type(value = LongMetaData.class, name = "Long"),
@JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"),
@JsonSubTypes.Type(value = StringMetaData.class, name = "String"),
@JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class ColumnMetaDataMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"),
@JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"),
@JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"),
@JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"),
@JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"),
@JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"),
@JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"),
@JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"),
@JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"),
@JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"),
@JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"),
@JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"),
@JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"),
@JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"),
@JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"),
@JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"),
@JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"),
@JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"),
@JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"),
@JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"),
@JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"),
@JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"),
@JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"),
@JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"),
@JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"),
@JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"),
@JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"),
@JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"),
@JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"),
@JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"),
@JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"),
@JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"),
@JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"),
@JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"),
@JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"),
@JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"),
@JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"),
@JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"),
@JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"),
@JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"),
@JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"),
@JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"),
@JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"),
@JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"),
@JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"),
@JsonSubTypes.Type(value = SequenceOffsetTransform.class, name = "SequenceOffsetTransform"),
@JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"),
@JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"),
@JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"),
@JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"),
@JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"),
@JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"),
@JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"),
@JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"),
@JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"),
@JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class TransformMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"),
@JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"),
@JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"),
@JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"),
@JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"),
@JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"),
@JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"),
@JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"),
@JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"),
@JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"),
@JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"),
@JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"),
@JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class ConditionMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"),
@JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"),
@JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"),
@JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"),
@JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"),
@JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"),
@JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"),
@JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"),
@JsonSubTypes.Type(value = Text.class, name = "Text"),
@JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class WritableMixin { }
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"),
@JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"),
@JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class FilterMixin { }
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"),
@JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class SequenceComparatorMixin { }
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"),
@JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class SequenceSplitMixin { }
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"),
@JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class WindowFunctionMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class CalculateSortedRankMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"),
@JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"),
@JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"),
@JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"),
@JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class WritableComparatorMixin { }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"),
@JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"),
@JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"),
@JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"),
@JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"),
@JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"),
@JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class ColumnAnalysisMixin{ }
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
@JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")})
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class IStringReducerMixin{ }
}

View File

@ -1,535 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform.serde.legacy;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.analysis.columns.*;
import org.datavec.api.transform.condition.BooleanCondition;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.condition.column.*;
import org.datavec.api.transform.condition.sequence.SequenceLengthCondition;
import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.filter.FilterInvalidValues;
import org.datavec.api.transform.filter.InvalidNumColumns;
import org.datavec.api.transform.metadata.*;
import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform;
import org.datavec.api.transform.ndarray.NDArrayDistanceTransform;
import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform;
import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ReduceSequenceTransform;
import org.datavec.api.transform.sequence.SequenceComparator;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator;
import org.datavec.api.transform.sequence.comparator.StringComparator;
import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation;
import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence;
import org.datavec.api.transform.sequence.trim.SequenceTrimTransform;
import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction;
import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform;
import org.datavec.api.transform.sequence.window.TimeWindowFunction;
import org.datavec.api.transform.sequence.window.WindowFunction;
import org.datavec.api.transform.stringreduce.IStringReducer;
import org.datavec.api.transform.stringreduce.StringReducer;
import org.datavec.api.transform.transform.categorical.*;
import org.datavec.api.transform.transform.column.*;
import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform;
import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault;
import org.datavec.api.transform.transform.doubletransform.*;
import org.datavec.api.transform.transform.integer.*;
import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform;
import org.datavec.api.transform.transform.longtransform.LongMathOpTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform;
import org.datavec.api.transform.transform.parse.ParseDoubleTransform;
import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform;
import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform;
import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform;
import org.datavec.api.transform.transform.string.*;
import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
import org.datavec.api.transform.transform.time.StringToTimeTransform;
import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*;
import org.datavec.api.writable.comparator.*;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import java.util.HashMap;
import java.util.Map;
public class LegacyMappingHelper {
public static Map<String,String> legacyMappingForClass(Class c){
//Need to be able to get the map - and they need to be mutable...
switch (c.getSimpleName()){
case "Transform":
return getLegacyMappingImageTransform();
case "ColumnAnalysis":
return getLegacyMappingColumnAnalysis();
case "Condition":
return getLegacyMappingCondition();
case "Filter":
return getLegacyMappingFilter();
case "ColumnMetaData":
return mapColumnMetaData;
case "CalculateSortedRank":
return mapCalculateSortedRank;
case "Schema":
return mapSchema;
case "SequenceComparator":
return mapSequenceComparator;
case "SequenceSplit":
return mapSequenceSplit;
case "WindowFunction":
return mapWindowFunction;
case "IStringReducer":
return mapIStringReducer;
case "Writable":
return mapWritable;
case "WritableComparator":
return mapWritableComparator;
case "ImageTransform":
return mapImageTransform;
default:
//Should never happen
throw new IllegalArgumentException("No legacy mapping available for class " + c.getName());
}
}
private static Map<String,String> mapTransform;
private static Map<String,String> mapColumnAnalysis;
private static Map<String,String> mapCondition;
private static Map<String,String> mapFilter;
private static Map<String,String> mapColumnMetaData;
private static Map<String,String> mapCalculateSortedRank;
private static Map<String,String> mapSchema;
private static Map<String,String> mapSequenceComparator;
private static Map<String,String> mapSequenceSplit;
private static Map<String,String> mapWindowFunction;
private static Map<String,String> mapIStringReducer;
private static Map<String,String> mapWritable;
private static Map<String,String> mapWritableComparator;
private static Map<String,String> mapImageTransform;
private static synchronized Map<String,String> getLegacyMappingTransform(){
if(mapTransform == null) {
//The following classes all used their class short name
Map<String, String> m = new HashMap<>();
m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName());
m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName());
m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName());
m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName());
m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName());
m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName());
m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName());
m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName());
m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName());
m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName());
m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName());
m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName());
m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName());
m.put("Log2Normalizer", Log2Normalizer.class.getName());
m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName());
m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName());
m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName());
m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName());
m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName());
m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName());
m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName());
m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName());
m.put("LongMathOpTransform", LongMathOpTransform.class.getName());
m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName());
m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName());
m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName());
m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName());
m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName());
m.put("StringMapTransform", StringMapTransform.class.getName());
m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName());
m.put("StringToTimeTransform", StringToTimeTransform.class.getName());
m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName());
m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName());
m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName());
m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName());
m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName());
m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName());
m.put("ConvertToStringTransform", ConvertToString.class.getName());
m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName());
m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName());
m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName());
m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName());
m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName());
m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName());
m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName());
m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName());
m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName());
m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName());
m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName());
m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName());
m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName());
m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName());
m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName());
m.put("PivotTransform", PivotTransform.class.getName());
m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName());
m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName());
m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName());
mapTransform = m;
}
return mapTransform;
}
private static Map<String,String> getLegacyMappingColumnAnalysis(){
if(mapColumnAnalysis == null) {
Map<String, String> m = new HashMap<>();
m.put("BytesAnalysis", BytesAnalysis.class.getName());
m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName());
m.put("DoubleAnalysis", DoubleAnalysis.class.getName());
m.put("IntegerAnalysis", IntegerAnalysis.class.getName());
m.put("LongAnalysis", LongAnalysis.class.getName());
m.put("StringAnalysis", StringAnalysis.class.getName());
m.put("TimeAnalysis", TimeAnalysis.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName());
mapColumnAnalysis = m;
}
return mapColumnAnalysis;
}
private static Map<String,String> getLegacyMappingCondition(){
if(mapCondition == null) {
Map<String, String> m = new HashMap<>();
m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName());
m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName());
m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName());
m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName());
m.put("LongColumnCondition", LongColumnCondition.class.getName());
m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName());
m.put("StringColumnCondition", StringColumnCondition.class.getName());
m.put("TimeColumnCondition", TimeColumnCondition.class.getName());
m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName());
m.put("BooleanCondition", BooleanCondition.class.getName());
m.put("NaNColumnCondition", NaNColumnCondition.class.getName());
m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName());
m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName());
m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName());
mapCondition = m;
}
return mapCondition;
}
private static Map<String,String> getLegacyMappingFilter(){
if(mapFilter == null) {
Map<String, String> m = new HashMap<>();
m.put("ConditionFilter", ConditionFilter.class.getName());
m.put("FilterInvalidValues", FilterInvalidValues.class.getName());
m.put("InvalidNumCols", InvalidNumColumns.class.getName());
mapFilter = m;
}
return mapFilter;
}
private static Map<String,String> getLegacyMappingColumnMetaData(){
if(mapColumnMetaData == null) {
Map<String, String> m = new HashMap<>();
m.put("Categorical", CategoricalMetaData.class.getName());
m.put("Double", DoubleMetaData.class.getName());
m.put("Float", FloatMetaData.class.getName());
m.put("Integer", IntegerMetaData.class.getName());
m.put("Long", LongMetaData.class.getName());
m.put("String", StringMetaData.class.getName());
m.put("Time", TimeMetaData.class.getName());
m.put("NDArray", NDArrayMetaData.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName());
m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName());
mapColumnMetaData = m;
}
return mapColumnMetaData;
}
private static Map<String,String> getLegacyMappingCalculateSortedRank(){
if(mapCalculateSortedRank == null) {
Map<String, String> m = new HashMap<>();
m.put("CalculateSortedRank", CalculateSortedRank.class.getName());
mapCalculateSortedRank = m;
}
return mapCalculateSortedRank;
}
private static Map<String,String> getLegacyMappingSchema(){
if(mapSchema == null) {
Map<String, String> m = new HashMap<>();
m.put("Schema", Schema.class.getName());
m.put("SequenceSchema", SequenceSchema.class.getName());
mapSchema = m;
}
return mapSchema;
}
private static Map<String,String> getLegacyMappingSequenceComparator(){
if(mapSequenceComparator == null) {
Map<String, String> m = new HashMap<>();
m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName());
m.put("StringComparator", StringComparator.class.getName());
mapSequenceComparator = m;
}
return mapSequenceComparator;
}
private static Map<String,String> getLegacyMappingSequenceSplit(){
if(mapSequenceSplit == null) {
Map<String, String> m = new HashMap<>();
m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName());
m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName());
mapSequenceSplit = m;
}
return mapSequenceSplit;
}
private static Map<String,String> getLegacyMappingWindowFunction(){
if(mapWindowFunction == null) {
Map<String, String> m = new HashMap<>();
m.put("TimeWindowFunction", TimeWindowFunction.class.getName());
m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName());
mapWindowFunction = m;
}
return mapWindowFunction;
}
private static Map<String,String> getLegacyMappingIStringReducer(){
if(mapIStringReducer == null) {
Map<String, String> m = new HashMap<>();
m.put("StringReducer", StringReducer.class.getName());
mapIStringReducer = m;
}
return mapIStringReducer;
}
private static Map<String,String> getLegacyMappingWritable(){
if (mapWritable == null) {
Map<String, String> m = new HashMap<>();
m.put("ArrayWritable", ArrayWritable.class.getName());
m.put("BooleanWritable", BooleanWritable.class.getName());
m.put("ByteWritable", ByteWritable.class.getName());
m.put("DoubleWritable", DoubleWritable.class.getName());
m.put("FloatWritable", FloatWritable.class.getName());
m.put("IntWritable", IntWritable.class.getName());
m.put("LongWritable", LongWritable.class.getName());
m.put("NullWritable", NullWritable.class.getName());
m.put("Text", Text.class.getName());
m.put("BytesWritable", BytesWritable.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName());
mapWritable = m;
}
return mapWritable;
}
private static Map<String,String> getLegacyMappingWritableComparator(){
if(mapWritableComparator == null) {
Map<String, String> m = new HashMap<>();
m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName());
m.put("FloatWritableComparator", FloatWritableComparator.class.getName());
m.put("IntWritableComparator", IntWritableComparator.class.getName());
m.put("LongWritableComparator", LongWritableComparator.class.getName());
m.put("TextWritableComparator", TextWritableComparator.class.getName());
//The following never had subtype annotations, and hence will have had the default name:
m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName());
m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName());
m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName());
m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName());
m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName());
m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName());
m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName());
m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName());
mapWritableComparator = m;
}
return mapWritableComparator;
}
public static Map<String,String> getLegacyMappingImageTransform(){
if(mapImageTransform == null) {
Map<String, String> m = new HashMap<>();
m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform");
m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform");
m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform");
m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform");
m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform");
m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform");
m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform");
m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform");
m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform");
m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform");
m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform");
m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform");
mapImageTransform = m;
}
return mapImageTransform;
}
@JsonDeserialize(using = LegacyTransformDeserializer.class)
public static class TransformHelper { }
public static class LegacyTransformDeserializer extends GenericLegacyDeserializer<Transform> {
public LegacyTransformDeserializer() {
super(Transform.class, getLegacyMappingTransform());
}
}
@JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class)
public static class ColumnAnalysisHelper { }
public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer<ColumnAnalysis> {
public LegacyColumnAnalysisDeserializer() {
super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis());
}
}
@JsonDeserialize(using = LegacyConditionDeserializer.class)
public static class ConditionHelper { }
public static class LegacyConditionDeserializer extends GenericLegacyDeserializer<Condition> {
public LegacyConditionDeserializer() {
super(Condition.class, getLegacyMappingCondition());
}
}
@JsonDeserialize(using = LegacyFilterDeserializer.class)
public static class FilterHelper { }
public static class LegacyFilterDeserializer extends GenericLegacyDeserializer<Filter> {
public LegacyFilterDeserializer() {
super(Filter.class, getLegacyMappingFilter());
}
}
@JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class)
public static class ColumnMetaDataHelper { }
public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer<ColumnMetaData> {
public LegacyColumnMetaDataDeserializer() {
super(ColumnMetaData.class, getLegacyMappingColumnMetaData());
}
}
@JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class)
public static class CalculateSortedRankHelper { }
public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer<CalculateSortedRank> {
public LegacyCalculateSortedRankDeserializer() {
super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank());
}
}
@JsonDeserialize(using = LegacySchemaDeserializer.class)
public static class SchemaHelper { }
public static class LegacySchemaDeserializer extends GenericLegacyDeserializer<Schema> {
public LegacySchemaDeserializer() {
super(Schema.class, getLegacyMappingSchema());
}
}
@JsonDeserialize(using = LegacySequenceComparatorDeserializer.class)
public static class SequenceComparatorHelper { }
public static class LegacySequenceComparatorDeserializer extends GenericLegacyDeserializer<SequenceComparator> {
public LegacySequenceComparatorDeserializer() {
super(SequenceComparator.class, getLegacyMappingSequenceComparator());
}
}
@JsonDeserialize(using = LegacySequenceSplitDeserializer.class)
public static class SequenceSplitHelper { }
public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer<SequenceSplit> {
public LegacySequenceSplitDeserializer() {
super(SequenceSplit.class, getLegacyMappingSequenceSplit());
}
}
@JsonDeserialize(using = LegacyWindowFunctionDeserializer.class)
public static class WindowFunctionHelper { }
public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer<WindowFunction> {
public LegacyWindowFunctionDeserializer() {
super(WindowFunction.class, getLegacyMappingWindowFunction());
}
}
@JsonDeserialize(using = LegacyIStringReducerDeserializer.class)
public static class IStringReducerHelper { }
public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer<IStringReducer> {
public LegacyIStringReducerDeserializer() {
super(IStringReducer.class, getLegacyMappingIStringReducer());
}
}
@JsonDeserialize(using = LegacyWritableDeserializer.class)
public static class WritableHelper { }
public static class LegacyWritableDeserializer extends GenericLegacyDeserializer<Writable> {
public LegacyWritableDeserializer() {
super(Writable.class, getLegacyMappingWritable());
}
}
@JsonDeserialize(using = LegacyWritableComparatorDeserializer.class)
public static class WritableComparatorHelper { }
public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer<WritableComparator> {
public LegacyWritableComparatorDeserializer() {
super(WritableComparator.class, getLegacyMappingWritableComparator());
}
}
}

View File

@ -17,7 +17,6 @@
package org.datavec.api.transform.stringreduce; package org.datavec.api.transform.stringreduce;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
@ -31,8 +30,7 @@ import java.util.List;
* a single List<Writable> * a single List<Writable>
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.IStringReducerHelper.class)
public interface IStringReducer extends Serializable { public interface IStringReducer extends Serializable {
/** /**

View File

@ -16,7 +16,7 @@
package org.datavec.api.util.ndarray; package org.datavec.api.util.ndarray;
import com.google.common.base.Preconditions; import org.nd4j.shade.guava.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.NonNull; import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;

View File

@ -17,7 +17,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.math.DoubleMath; import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator; import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -17,7 +17,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.math.DoubleMath; import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator; import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -17,7 +17,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.math.DoubleMath; import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator; import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -17,7 +17,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.math.DoubleMath; import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator; import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -17,7 +17,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.math.DoubleMath; import org.nd4j.shade.guava.math.DoubleMath;
import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparable;
import org.datavec.api.io.WritableComparator; import org.datavec.api.io.WritableComparator;
import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonProperty;

View File

@ -16,7 +16,6 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.DataInput; import java.io.DataInput;
@ -60,8 +59,7 @@ import java.io.Serializable;
* } * }
* </pre></blockquote></p> * </pre></blockquote></p>
*/ */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.WritableHelper.class)
public interface Writable extends Serializable { public interface Writable extends Serializable {
/** /**
* Serialize the fields of this object to <code>out</code>. * Serialize the fields of this object to <code>out</code>.

View File

@ -16,7 +16,7 @@
package org.datavec.api.writable.batch; package org.datavec.api.writable.batch;
import com.google.common.base.Preconditions; import org.nd4j.shade.guava.base.Preconditions;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;

View File

@ -16,16 +16,13 @@
package org.datavec.api.writable.comparator; package org.datavec.api.writable.comparator;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.io.Serializable; import java.io.Serializable;
import java.util.Comparator; import java.util.Comparator;
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class)
public interface WritableComparator extends Comparator<Writable>, Serializable { public interface WritableComparator extends Comparator<Writable>, Serializable {
} }

View File

@ -16,7 +16,7 @@
package org.datavec.api.split; package org.datavec.api.split;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.ParentPathLabelGenerator;

View File

@ -16,7 +16,7 @@
package org.datavec.api.split.parittion; package org.datavec.api.split.parittion;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;

View File

@ -78,8 +78,9 @@ public class TestJsonYaml {
public void testMissingPrimitives() { public void testMissingPrimitives() {
Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build(); Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build();
//Legacy format JSON
String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n" String strJson = "{\n" + " \"Schema\" : {\n"
+ " \"columns\" : [ {\n" + " \"Double\" : {\n"
+ " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" + + " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" +
//" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test //" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test
//" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test //" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test

View File

@ -16,7 +16,7 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import com.google.common.collect.Lists; import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test; import org.junit.Test;

View File

@ -34,42 +34,6 @@
<artifactId>nd4j-arrow</artifactId> <artifactId>nd4j-arrow</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-xml</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-joda</artifactId>
<version>${spark2.jackson.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
@ -80,11 +44,6 @@
<artifactId>hppc</artifactId> <artifactId>hppc</artifactId>
<version>${hppc.version}</version> <version>${hppc.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.arrow</groupId> <groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId> <artifactId>arrow-vector</artifactId>

View File

@ -43,12 +43,6 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<!-- Do not depend on FFmpeg by default due to licensing concerns. --> <!-- Do not depend on FFmpeg by default due to licensing concerns. -->
<!-- <!--
@ -60,7 +54,6 @@
--> -->
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>

View File

@ -31,20 +31,12 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
<version>${logback.version}</version> <version>${logback.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-buffer</artifactId> <artifactId>nd4j-buffer</artifactId>
@ -75,7 +67,6 @@
<artifactId>imageio-bmp</artifactId> <artifactId>imageio-bmp</artifactId>
<version>3.1.1</version> <version>3.1.1</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.google.android</groupId> <groupId>com.google.android</groupId>
<artifactId>android</artifactId> <artifactId>android</artifactId>
@ -88,7 +79,6 @@
</exclusions> </exclusions>
<optional>true</optional> <optional>true</optional>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>javacpp</artifactId> <artifactId>javacpp</artifactId>
@ -99,25 +89,21 @@
<artifactId>javacv</artifactId> <artifactId>javacv</artifactId>
<version>${javacv.version}</version> <version>${javacv.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>opencv-platform</artifactId> <artifactId>opencv-platform</artifactId>
<version>${opencv.version}-${javacpp-presets.version}</version> <version>${opencv.version}-${javacpp-presets.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>leptonica-platform</artifactId> <artifactId>leptonica-platform</artifactId>
<version>${leptonica.version}-${javacpp-presets.version}</version> <version>${leptonica.version}-${javacpp-presets.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.bytedeco</groupId> <groupId>org.bytedeco</groupId>
<artifactId>hdf5-platform</artifactId> <artifactId>hdf5-platform</artifactId>
<version>${hdf5.version}-${javacpp-presets.version}</version> <version>${hdf5.version}-${javacpp-presets.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>
<build> <build>
@ -143,5 +129,4 @@
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.1</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -16,7 +16,7 @@
package org.datavec.image.recordreader; package org.datavec.image.recordreader;
import com.google.common.base.Preconditions; import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -1,35 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.image.serde;
import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer;
import org.datavec.api.transform.serde.legacy.LegacyMappingHelper;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
public class LegacyImageMappingHelper {
@JsonDeserialize(using = LegacyImageTransformDeserializer.class)
public static class ImageTransformHelper { }
public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer<ImageTransform> {
public LegacyImageTransformDeserializer() {
super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform());
}
}
}

View File

@ -16,11 +16,8 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.datavec.image.serde.LegacyImageMappingHelper;
import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonSubTypes;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import java.util.Random; import java.util.Random;
@ -31,8 +28,7 @@ import java.util.Random;
* @author saudet * @author saudet
*/ */
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class)
public interface ImageTransform { public interface ImageTransform {
/** /**

View File

@ -31,7 +31,6 @@
<properties> <properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<cleartk.version>2.0.0</cleartk.version> <cleartk.version>2.0.0</cleartk.version>
</properties> </properties>
<dependencies> <dependencies>
@ -75,13 +74,6 @@
<artifactId>cleartk-opennlp-tools</artifactId> <artifactId>cleartk-opennlp-tools</artifactId>
<version>${cleartk.version}</version> <version>${cleartk.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>

View File

@ -50,11 +50,6 @@
<artifactId>netty</artifactId> <artifactId>netty</artifactId>
<version>${netty.version}</version> <version>${netty.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId> <artifactId>commons-compress</artifactId>
@ -95,14 +90,6 @@
</exclusion> </exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
<!-- Test dependencies -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader; package org.datavec.hadoop.records.reader;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*; import org.apache.hadoop.io.*;

View File

@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader; package org.datavec.hadoop.records.reader;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*; import org.apache.hadoop.io.*;

View File

@ -16,7 +16,7 @@
package org.datavec.hadoop.records.reader; package org.datavec.hadoop.records.reader;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*; import org.apache.hadoop.io.*;

View File

@ -16,7 +16,7 @@
package org.datavec.hadoop.records.writer; package org.datavec.hadoop.records.writer;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.records.converter.RecordReaderConverter; import org.datavec.api.records.converter.RecordReaderConverter;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.SequenceRecordReader;

View File

@ -50,12 +50,6 @@
<artifactId>protonpack</artifactId> <artifactId>protonpack</artifactId>
<version>${protonpack.version}</version> <version>${protonpack.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>

View File

@ -16,7 +16,7 @@
package org.datavec.local.transforms.join; package org.datavec.local.transforms.join;
import com.google.common.collect.Iterables; import org.nd4j.shade.guava.collect.Iterables;
import org.datavec.api.transform.join.Join; import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter;

View File

@ -52,12 +52,6 @@
<version>${project.version}</version> <version>${project.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>

View File

@ -26,10 +26,8 @@
<artifactId>datavec-python</artifactId> <artifactId>datavec-python</artifactId>
<dependencies> <dependencies>
<dependency>
<dependency>
<groupId>com.googlecode.json-simple</groupId> <groupId>com.googlecode.json-simple</groupId>
<artifactId>json-simple</artifactId> <artifactId>json-simple</artifactId>
<version>1.1</version> <version>1.1</version>
@ -39,11 +37,6 @@
<artifactId>cpython-platform</artifactId> <artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version> <version>${cpython-platform.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId> <artifactId>jsr305</artifactId>
@ -54,14 +47,19 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
<version>${logback.version}</version> <version>${logback.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
@ -70,5 +68,4 @@
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.1</id>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -30,12 +30,6 @@
<dependencies> <dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-jackson</artifactId> <artifactId>nd4j-jackson</artifactId>

View File

@ -35,12 +35,6 @@
<artifactId>datavec-api</artifactId> <artifactId>datavec-api</artifactId>
<version>${datavec.version}</version> <version>${datavec.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId> <artifactId>datavec-data-image</artifactId>

View File

@ -64,12 +64,6 @@
<version>${datavec.version}</version> <version>${datavec.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-cluster_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency> <dependency>
<groupId>joda-time</groupId> <groupId>joda-time</groupId>
<artifactId>joda-time</artifactId> <artifactId>joda-time</artifactId>
@ -106,40 +100,10 @@
<version>${snakeyaml.version}</version> <version>${snakeyaml.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId> <artifactId>play-java_2.11</artifactId>
<version>${play.version}</version> <version>${playframework.version}</version>
<exclusions> <exclusions>
<exclusion> <exclusion>
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
@ -161,25 +125,31 @@
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId> <artifactId>play-json_2.11</artifactId>
<version>${play.version}</version> <version>${playframework.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId> <artifactId>play-server_2.11</artifactId>
<version>${play.version}</version> <version>${playframework.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId> <artifactId>play_2.11</artifactId>
<version>${play.version}</version> <version>${playframework.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>com.typesafe.play</groupId> <groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId> <artifactId>play-netty-server_2.11</artifactId>
<version>${play.version}</version> <version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-cluster_2.11</artifactId>
<version>2.5.23</version>
</dependency> </dependency>
<dependency> <dependency>
@ -195,14 +165,11 @@
<version>${jcommander.version}</version> <version>${jcommander.version}</version>
</dependency> </dependency>
<!-- Test Scope Dependencies -->
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.apache.spark</groupId>
<artifactId>nd4j-native</artifactId> <artifactId>spark-core_2.11</artifactId>
<version>${nd4j.version}</version> <version>${spark.version}</version>
<scope>test</scope>
</dependency> </dependency>
</dependencies> </dependencies>
<profiles> <profiles>

View File

@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess; import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*; import org.datavec.spark.transform.model.*;
import play.BuiltInComponents;
import play.Mode; import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl; import play.routing.RoutingDsl;
import play.server.Server; import play.server.Server;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Base64;
import java.util.Random;
import static play.mvc.Results.*; import static play.mvc.Results.*;
@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer {
System.exit(1); System.exit(1);
} }
RoutingDsl routingDsl = new RoutingDsl();
if (jsonPath != null) { if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath)); String json = FileUtils.readFileToString(new File(jsonPath));
TransformProcess transformProcess = TransformProcess.fromJson(json); TransformProcess transformProcess = TransformProcess.fromJson(json);
@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess"); + "to /transformprocess");
} }
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
String crypto = System.getProperty("play.crypto.secret");
if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
byte[] newCrypto = new byte[1024];
routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { new Random().nextBytes(newCrypto);
String base64 = Base64.getEncoder().encodeToString(newCrypto);
System.setProperty("play.crypto.secret", base64);
}
server = Server.forRouter(Mode.PROD, port, this::createRouter);
}
protected Router createRouter(BuiltInComponents b){
RoutingDsl routingDsl = RoutingDsl.fromComponents(b);
routingDsl.GET("/transformprocess").routingTo(req -> {
try { try {
if (transform == null) if (transform == null)
return badRequest(); return badRequest();
@ -88,11 +107,11 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in GET /transformprocess",e); log.error("Error in GET /transformprocess",e);
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
}))); });
routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformprocess").routingTo(req -> {
try { try {
TransformProcess transformProcess = TransformProcess.fromJson(getJsonText()); TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
setCSVTransformProcess(transformProcess); setCSVTransformProcess(transformProcess);
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@ -100,12 +119,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
log.error("Error in POST /transformprocess",e); log.error("Error in POST /transformprocess",e);
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
}))); });
routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformincremental").routingTo(req -> {
if (isSequence()) { if (isSequence(req)) {
try { try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null) if (record == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
@ -115,7 +134,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
} }
} else { } else {
try { try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null) if (record == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
@ -124,12 +143,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
} }
}))); });
routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transform").routingTo(req -> {
if (isSequence()) { if (isSequence(req)) {
try { try {
SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class)); SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
if (batch == null) if (batch == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(batch)).as(contentType); return ok(objectMapper.writeValueAsString(batch)).as(contentType);
@ -139,7 +158,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
} }
} else { } else {
try { try {
BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
BatchCSVRecord batch = transform(input); BatchCSVRecord batch = transform(input);
if (batch == null) if (batch == null)
return badRequest(); return badRequest();
@ -149,14 +168,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
} }
});
routingDsl.POST("/transformincrementalarray").routingTo(req -> {
}))); if (isSequence(req)) {
routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
try { try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null) if (record == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
@ -166,7 +183,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
} }
} else { } else {
try { try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null) if (record == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
@ -175,13 +192,12 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
} }
});
}))); routingDsl.POST("/transformarray").routingTo(req -> {
if (isSequence(req)) {
routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> {
if (isSequence()) {
try { try {
SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class); SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
if (batchCSVRecord == null) if (batchCSVRecord == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
@ -191,7 +207,7 @@ public class CSVSparkTransformServer extends SparkTransformServer {
} }
} else { } else {
try { try {
BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (batchCSVRecord == null) if (batchCSVRecord == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
@ -200,10 +216,9 @@ public class CSVSparkTransformServer extends SparkTransformServer {
return internalServerError(e.getMessage()); return internalServerError(e.getMessage());
} }
} }
}))); });
return routingDsl.build();
server = Server.forRouter(routingDsl.build(), Mode.PROD, port);
} }
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import play.libs.F;
import play.mvc.Result;
import java.util.function.Function;
import java.util.function.Supplier;
/**
* Utility methods for Routing
*
* @author Alex Black
*/
public class FunctionUtil {
public static F.Function0<Result> function0(Supplier<Result> supplier) {
return supplier::get;
}
public static <T> F.Function<T, Result> function(Function<T, Result> function) {
return function::apply;
}
}

View File

@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess; import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.transform.model.*; import org.datavec.spark.transform.model.*;
import play.BuiltInComponents;
import play.Mode; import play.Mode;
import play.libs.Files;
import play.mvc.Http; import play.mvc.Http;
import play.routing.Router;
import play.routing.RoutingDsl; import play.routing.RoutingDsl;
import play.server.Server; import play.server.Server;
@ -33,6 +36,7 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.Function;
import static play.mvc.Controller.request; import static play.mvc.Controller.request;
import static play.mvc.Results.*; import static play.mvc.Results.*;
@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer {
System.exit(1); System.exit(1);
} }
RoutingDsl routingDsl = new RoutingDsl();
if (jsonPath != null) { if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath)); String json = FileUtils.readFileToString(new File(jsonPath));
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
@ -73,7 +75,13 @@ public class ImageSparkTransformServer extends SparkTransformServer {
+ "to /transformprocess"); + "to /transformprocess");
} }
routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { server = Server.forRouter(Mode.PROD, port, this::createRouter);
}
protected Router createRouter(BuiltInComponents builtInComponents){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
routingDsl.GET("/transformprocess").routingTo(req -> {
try { try {
if (transform == null) if (transform == null)
return badRequest(); return badRequest();
@ -83,11 +91,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformprocess").routingTo(req -> {
try { try {
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText()); ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
setImageTransformProcess(transformProcess); setImageTransformProcess(transformProcess);
log.info("Transform process initialized"); log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
@ -95,11 +103,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformincrementalarray").routingTo(req -> {
try { try {
SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class); SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
if (record == null) if (record == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@ -107,17 +115,17 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformincrementalimage").routingTo(req -> {
try { try {
Http.MultipartFormData body = request().body().asMultipartFormData(); Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart> files = body.getFiles(); List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.size() == 0 || files.get(0).getFile() == null) { if (files.isEmpty() || files.get(0).getRef() == null ) {
return badRequest(); return badRequest();
} }
File file = files.get(0).getFile(); File file = files.get(0).getRef().path().toFile();
SingleImageRecord record = new SingleImageRecord(file.toURI()); SingleImageRecord record = new SingleImageRecord(file.toURI());
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
@ -125,11 +133,11 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformarray").routingTo(req -> {
try { try {
BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class); BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
if (batch == null) if (batch == null)
return badRequest(); return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
@ -137,22 +145,22 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> { routingDsl.POST("/transformimage").routingTo(req -> {
try { try {
Http.MultipartFormData body = request().body().asMultipartFormData(); Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart> files = body.getFiles(); List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.size() == 0) { if (files.size() == 0) {
return badRequest(); return badRequest();
} }
List<SingleImageRecord> records = new ArrayList<>(); List<SingleImageRecord> records = new ArrayList<>();
for (Http.MultipartFormData.FilePart filePart : files) { for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
File file = filePart.getFile(); Files.TemporaryFile file = filePart.getRef();
if (file != null) { if (file != null) {
SingleImageRecord record = new SingleImageRecord(file.toURI()); SingleImageRecord record = new SingleImageRecord(file.path().toUri());
records.add(record); records.add(record);
} }
} }
@ -164,9 +172,9 @@ public class ImageSparkTransformServer extends SparkTransformServer {
e.printStackTrace(); e.printStackTrace();
return internalServerError(); return internalServerError();
} }
}))); });
server = Server.forRouter(routingDsl.build(), Mode.PROD, port); return routingDsl.build();
} }
@Override @Override

View File

@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.service.DataVecTransformService; import org.datavec.spark.transform.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.server.Server; import play.server.Server;
import static play.mvc.Controller.request; import static play.mvc.Controller.request;
@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService {
server.stop(); server.stop();
} }
protected boolean isSequence() { protected boolean isSequence(Http.Request request) {
return request().hasHeader(SEQUENCE_OR_NOT_HEADER) return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
&& request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase() && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
.equals("TRUE");
} }
protected String getJsonText(Http.Request request) {
protected String getHeaderValue(String value) { JsonNode tryJson = request.body().asJson();
if (request().hasHeader(value))
return request().getHeader(value);
return null;
}
protected String getJsonText() {
JsonNode tryJson = request().body().asJson();
if (tryJson != null) if (tryJson != null)
return tryJson.toString(); return tryJson.toString();
else else
return request().body().asText(); return request.body().asText();
} }
public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);

View File

@ -0,0 +1,350 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other subsitution, then
# HOCON will fall back to substituting environment variable:
#mykey = ${JAVA_HOME}
## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
# "akka.log-config-on-start" is extraordinarly useful because it log the complete
# configuration at INFO level, including defaults and overrides, so it s worth
# putting at the very top.
#
# Put the following in your conf/logback.xml file:
#
# <logger name="akka.actor" level="INFO" />
#
# And then uncomment this line to debug the configuration.
#
#log-config-on-start = true
}
## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publically available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
# By default, Play will load any class called Module that is defined
# in the root package (the "app" directory), or you can define them
# explicitly below.
# If there are any built-in modules that you want to disable, you can list them here.
#enabled += my.application.Module
# If there are any built-in modules that you want to disable, you can list them here.
#disabled += ""
}
## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
# The application languages
langs = [ "en" ]
# Whether the language cookie should be secure or not
#langCookieSecure = true
# Whether the HTTP only attribute of the cookie should be set to true
#langCookieHttpOnly = true
}
## Play HTTP settings
# ~~~~~
play.http {
## Router
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# Define the Router object to use for this application.
# This router will be looked up first when the application is starting up,
# so make sure this is the entry point.
# Furthermore, it's assumed your route file is named properly.
# So for an application router like `my.application.Router`,
# you may need to define a router file `conf/my.application.routes`.
# Default to Routes in the root package (aka "apps" folder) (and conf/routes)
#router = my.application.Router
## Action Creator
# https://www.playframework.com/documentation/latest/JavaActionCreator
# ~~~~~
#actionCreator = null
## ErrorHandler
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# If null, will attempt to load a class called ErrorHandler in the root package,
#errorHandler = null
## Filters
# https://www.playframework.com/documentation/latest/ScalaHttpFilters
# https://www.playframework.com/documentation/latest/JavaHttpFilters
# ~~~~~
# Filters run code on every request. They can be used to perform
# common logic for all your actions, e.g. adding common headers.
# Defaults to "Filters" in the root package (aka "apps" folder)
# Alternatively you can explicitly register a class here.
#filters += my.application.Filters
## Session & Flash
# https://www.playframework.com/documentation/latest/JavaSessionFlash
# https://www.playframework.com/documentation/latest/ScalaSessionFlash
# ~~~~~
session {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
# Sets the max-age field of the cookie to 5 minutes.
# NOTE: this only sets when the browser will discard the cookie. Play will consider any
# cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
# you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
#maxAge = 300
# Sets the domain on the session cookie.
#domain = "example.com"
}
flash {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
}
}
## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
# Whether the Netty wire should be logged
#log.wire = true
# If you run Play on Linux, you can use Netty's native socket transport
# for higher performance with less garbage.
#transport = "native"
}
## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
# Sets HTTP requests not to follow 302 requests
#followRedirects = false
# Sets the maximum number of open HTTP connections for the client.
#ahc.maxConnectionsTotal = 50
## WS SSL
# https://www.playframework.com/documentation/latest/WsSSL
# ~~~~~
ssl {
# Configuring HTTPS with Play WS does not require programming. You can
# set up both trustManager and keyManager for mutual authentication, and
# turn on JSSE debugging in development with a reload.
#debug.handshake = true
#trustManager = {
# stores = [
# { type = "JKS", path = "exampletrust.jks" }
# ]
#}
}
}
## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
# If you want to bind several caches, you can bind the individually
#bindCaches = ["db-cache", "user-cache", "session-cache"]
}
## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
## CORS filter configuration
# https://www.playframework.com/documentation/latest/CorsFilter
# ~~~~~
# CORS is a protocol that allows web applications to make requests from the browser
# across different domains.
# NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
# dependencies on CORS settings.
cors {
# Filter paths by a whitelist of path prefixes
#pathPrefixes = ["/some/path", ...]
# The allowed origins. If null, all origins are allowed.
#allowedOrigins = ["http://www.example.com"]
# The allowed HTTP methods. If null, all methods are allowed
#allowedHttpMethods = ["GET", "POST"]
}
## CSRF Filter
# https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
# https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
# ~~~~~
# Play supports multiple methods for verifying that a request is not a CSRF request.
# The primary mechanism is a CSRF token. This token gets placed either in the query string
# or body of every form submitted, and also gets placed in the users session.
# Play then verifies that both tokens are present and match.
csrf {
# Sets the cookie to be sent only over HTTPS
#cookie.secure = true
# Defaults to CSRFErrorHandler in the root package.
#errorHandler = MyCSRFErrorHandler
}
## Security headers filter configuration
# https://www.playframework.com/documentation/latest/SecurityHeaders
# ~~~~~
# Defines security headers that prevent XSS attacks.
# If enabled, then all options are set to the below configuration by default:
headers {
# The X-Frame-Options header. If null, the header is not set.
#frameOptions = "DENY"
# The X-XSS-Protection header. If null, the header is not set.
#xssProtection = "1; mode=block"
# The X-Content-Type-Options header. If null, the header is not set.
#contentTypeOptions = "nosniff"
# The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
#permittedCrossDomainPolicies = "master-only"
# The Content-Security-Policy header. If null, the header is not set.
#contentSecurityPolicy = "default-src 'self'"
}
## Allowed hosts filter configuration
# https://www.playframework.com/documentation/latest/AllowedHostsFilter
# ~~~~~
# Play provides a filter that lets you configure which hosts can access your application.
# This is useful to prevent cache poisoning attacks.
hosts {
# Allow requests to example.com, its subdomains, and localhost:9000.
#allowed = [".example.com", "localhost:9000"]
}
}
## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
# You can disable evolutions for a specific datasource if necessary
#db.default.enabled = false
}
## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
# The combination of these two settings results in "db.default" as the
# default JDBC pool:
#config = "db"
#default = "default"
# Play uses HikariCP as the default connection pool. You can override
# settings by changing the prototype:
prototype {
# Sets a fixed JDBC connection pool size of 50
#hikaricp.minimumIdle = 50
#hikaricp.maximumPoolSize = 50
}
}
## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
# You can declare as many datasources as you want.
# By convention, the default datasource is named `default`
# https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
default.driver = org.h2.Driver
default.url = "jdbc:h2:mem:play"
#default.username = sa
#default.password = ""
# You can expose this datasource via JNDI if needed (Useful for JPA)
default.jndiName=DefaultDS
# You can turn on SQL logging for any datasource
# https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
#default.logSql=true
}
jpa.default=defaultPersistenceUnit
#Increase default maximum post length - used for remote listener functionality
#Can get response 413 with larger networks without setting this
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M

View File

@ -28,61 +28,11 @@
<artifactId>datavec-spark_2.11</artifactId> <artifactId>datavec-spark_2.11</artifactId>
<properties> <properties>
<!-- These Spark versions have to be here, and NOT in the datavec parent pom. Otherwise, the properties
will be resolved to their defaults. Whereas when defined here (the version outputStream spark 1 vs. 2 specific) we can
have different version properties simultaneously (in different pom files) instead of a single global property -->
<spark.version>2.1.0</spark.version>
<spark.major.version>2</spark.major.version>
<!-- Default scala versions, may be overwritten by build profiles --> <!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version> <scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version> <scala.binary.version>2.11</scala.binary.version>
</properties> </properties>
<build>
<plugins>
<!-- added source folder containing the code specific to the spark version -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/spark-${spark.major.version}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jsr310</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-yaml</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.11</artifactId>
<version>${jackson.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.scala-lang</groupId> <groupId>org.scala-lang</groupId>
@ -95,42 +45,13 @@
<version>${scala.version}</version> <version>${scala.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<version>${jackson-asl.version}</version>
</dependency>
<dependency>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
<version>${jackson-asl.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.spark</groupId> <groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId> <artifactId>spark-sql_2.11</artifactId>
<version>${spark.version}</version> <version>${spark.version}</version>
<scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<dependency>
<groupId>com.google.inject</groupId>
<artifactId>guice</artifactId>
<version>${guice.version}</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${google.protobuf.version}</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>${commons-codec.version}</version>
</dependency>
<dependency> <dependency>
<groupId>commons-collections</groupId> <groupId>commons-collections</groupId>
<artifactId>commons-collections</artifactId> <artifactId>commons-collections</artifactId>
@ -141,96 +62,16 @@
<artifactId>commons-io</artifactId> <artifactId>commons-io</artifactId>
<version>${commons-io.version}</version> <version>${commons-io.version}</version>
</dependency> </dependency>
<dependency>
<groupId>commons-lang</groupId>
<artifactId>commons-lang</artifactId>
<version>${commons-lang.version}</version>
</dependency>
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
<version>${commons-net.version}</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-core</artifactId>
<version>${jaxb.version}</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-impl</artifactId>
<version>${jaxb.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-remote_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-slf4j_2.11</artifactId>
<version>${akka.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>${servlet.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>${commons-compress.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId> <artifactId>commons-math3</artifactId>
<version>${commons-math3.version}</version> <version>${commons-math3.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
<version>${curator.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.slf4j</groupId> <groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId> <artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version> <version>${slf4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe</groupId>
<artifactId>config</artifactId>
<version>${typesafe.config.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.apache.spark</groupId> <groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId> <artifactId>spark-core_2.11</artifactId>
@ -241,14 +82,6 @@
<groupId>com.google.code.findbugs</groupId> <groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId> <artifactId>jsr305</artifactId>
</exclusion> </exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
</exclusions> </exclusions>
</dependency> </dependency>
@ -281,13 +114,6 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.datavec</groupId> <groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId> <artifactId>datavec-local</artifactId>

View File

@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.Function2;
import org.apache.spark.sql.Column; import org.apache.spark.sql.*;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructField;
@ -46,7 +43,6 @@ import java.util.List;
import static org.apache.spark.sql.functions.avg; import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.col; import static org.apache.spark.sql.functions.col;
import static org.datavec.spark.transform.DataRowsFacade.dataRows;
/** /**
@ -71,7 +67,7 @@ public class DataFrames {
* deviation for * deviation for
* @return the column that represents the standard deviation * @return the column that represents the standard deviation
*/ */
public static Column std(DataRowsFacade dataFrame, String columnName) { public static Column std(Dataset<Row> dataFrame, String columnName) {
return functions.sqrt(var(dataFrame, columnName)); return functions.sqrt(var(dataFrame, columnName));
} }
@ -85,8 +81,8 @@ public class DataFrames {
* deviation for * deviation for
* @return the column that represents the standard deviation * @return the column that represents the standard deviation
*/ */
public static Column var(DataRowsFacade dataFrame, String columnName) { public static Column var(Dataset<Row> dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName); return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName);
} }
/** /**
@ -97,8 +93,8 @@ public class DataFrames {
* @param columnName the name of the column to get the min for * @param columnName the name of the column to get the min for
* @return the column that represents the min * @return the column that represents the min
*/ */
public static Column min(DataRowsFacade dataFrame, String columnName) { public static Column min(Dataset<Row> dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName); return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName);
} }
/** /**
@ -110,8 +106,8 @@ public class DataFrames {
* to get the max for * to get the max for
* @return the column that represents the max * @return the column that represents the max
*/ */
public static Column max(DataRowsFacade dataFrame, String columnName) { public static Column max(Dataset<Row> dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName); return dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName);
} }
/** /**
@ -122,8 +118,8 @@ public class DataFrames {
* @param columnName the name of the column to get the mean for * @param columnName the name of the column to get the mean for
* @return the column that represents the mean * @return the column that represents the mean
*/ */
public static Column mean(DataRowsFacade dataFrame, String columnName) { public static Column mean(Dataset<Row> dataFrame, String columnName) {
return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName); return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName);
} }
/** /**
@ -166,7 +162,7 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.<br> * of this record in the original time series.<br>
* These two columns are required if the data is to be converted back into a sequence at a later point, for example * These two columns are required if the data is to be converted back into a sequence at a later point, for example
* using {@link #toRecordsSequence(DataRowsFacade)} * using {@link #toRecordsSequence(Dataset<Row>)}
* *
* @param schema Schema to convert * @param schema Schema to convert
* @return StructType for the schema * @return StructType for the schema
@ -250,9 +246,9 @@ public class DataFrames {
* @param dataFrame the dataframe to convert * @param dataFrame the dataframe to convert
* @return the converted schema and rdd of writables * @return the converted schema and rdd of writables
*/ */
public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(DataRowsFacade dataFrame) { public static Pair<Schema, JavaRDD<List<Writable>>> toRecords(Dataset<Row> dataFrame) {
Schema schema = fromStructType(dataFrame.get().schema()); Schema schema = fromStructType(dataFrame.schema());
return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema))); return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema)));
} }
/** /**
@ -267,11 +263,11 @@ public class DataFrames {
* @param dataFrame Data frame to convert * @param dataFrame Data frame to convert
* @return Data in sequence (i.e., {@code List<List<Writable>>} form * @return Data in sequence (i.e., {@code List<List<Writable>>} form
*/ */
public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(DataRowsFacade dataFrame) { public static Pair<Schema, JavaRDD<List<List<Writable>>>> toRecordsSequence(Dataset<Row> dataFrame) {
//Need to convert from flattened to sequence data... //Need to convert from flattened to sequence data...
//First: Group by the Sequence UUID (first column) //First: Group by the Sequence UUID (first column)
JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.get().javaRDD().groupBy(new Function<Row, String>() { JavaPairRDD<String, Iterable<Row>> grouped = dataFrame.javaRDD().groupBy(new Function<Row, String>() {
@Override @Override
public String call(Row row) throws Exception { public String call(Row row) throws Exception {
return row.getString(0); return row.getString(0);
@ -279,7 +275,7 @@ public class DataFrames {
}); });
Schema schema = fromStructType(dataFrame.get().schema()); Schema schema = fromStructType(dataFrame.schema());
//Group by sequence UUID, and sort each row within the sequences using the time step index //Group by sequence UUID, and sort each row within the sequences using the time step index
Function<Iterable<Row>, List<List<Writable>>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner Function<Iterable<Row>, List<List<Writable>>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner
@ -318,11 +314,11 @@ public class DataFrames {
* @param data the data to convert * @param data the data to convert
* @return the dataframe object * @return the dataframe object
*/ */
public static DataRowsFacade toDataFrame(Schema schema, JavaRDD<List<Writable>> data) { public static Dataset<Row> toDataFrame(Schema schema, JavaRDD<List<Writable>> data) {
JavaSparkContext sc = new JavaSparkContext(data.context()); JavaSparkContext sc = new JavaSparkContext(data.context());
SQLContext sqlContext = new SQLContext(sc); SQLContext sqlContext = new SQLContext(sc);
JavaRDD<Row> rows = data.map(new ToRow(schema)); JavaRDD<Row> rows = data.map(new ToRow(schema));
return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema))); return sqlContext.createDataFrame(rows, fromSchema(schema));
} }
@ -333,18 +329,18 @@ public class DataFrames {
* - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position
* of this record in the original time series.<br> * of this record in the original time series.<br>
* These two columns are required if the data is to be converted back into a sequence at a later point, for example * These two columns are required if the data is to be converted back into a sequence at a later point, for example
* using {@link #toRecordsSequence(DataRowsFacade)} * using {@link #toRecordsSequence(Dataset<Row>)}
* *
* @param schema Schema for the data * @param schema Schema for the data
* @param data Sequence data to convert to a DataFrame * @param data Sequence data to convert to a DataFrame
* @return The dataframe object * @return The dataframe object
*/ */
public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) { public static Dataset<Row> toDataFrameSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
JavaSparkContext sc = new JavaSparkContext(data.context()); JavaSparkContext sc = new JavaSparkContext(data.context());
SQLContext sqlContext = new SQLContext(sc); SQLContext sqlContext = new SQLContext(sc);
JavaRDD<Row> rows = data.flatMap(new SequenceToRows(schema)); JavaRDD<Row> rows = data.flatMap(new SequenceToRows(schema));
return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema))); return sqlContext.createDataFrame(rows, fromSchemaSequence(schema));
} }
/** /**

View File

@ -19,14 +19,13 @@ package org.datavec.spark.transform;
import org.apache.commons.collections.map.ListOrderedMap; import org.apache.commons.collections.map.ListOrderedMap;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column; import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import java.util.*; import java.util.*;
import static org.datavec.spark.transform.DataRowsFacade.dataRows;
/** /**
* Simple dataframe based normalization. * Simple dataframe based normalization.
@ -46,7 +45,7 @@ public class Normalization {
* @return a zero mean unit variance centered * @return a zero mean unit variance centered
* rdd * rdd
*/ */
public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) { public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame) {
return zeromeanUnitVariance(frame, Collections.<String>emptyList()); return zeromeanUnitVariance(frame, Collections.<String>emptyList());
} }
@ -71,7 +70,7 @@ public class Normalization {
* @param max the maximum value * @param max the maximum value
* @return the normalized dataframe per column * @return the normalized dataframe per column
*/ */
public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) { public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max) {
return normalize(dataFrame, min, max, Collections.<String>emptyList()); return normalize(dataFrame, min, max, Collections.<String>emptyList());
} }
@ -86,7 +85,7 @@ public class Normalization {
*/ */
public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min,
double max) { double max) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data); Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, Collections.<String>emptyList())).getSecond(); return DataFrames.toRecords(normalize(frame, min, max, Collections.<String>emptyList())).getSecond();
} }
@ -97,7 +96,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale * @param dataFrame the dataframe to scale
* @return the normalized dataframe per column * @return the normalized dataframe per column
*/ */
public static DataRowsFacade normalize(DataRowsFacade dataFrame) { public static Dataset<Row> normalize(Dataset<Row> dataFrame) {
return normalize(dataFrame, 0, 1, Collections.<String>emptyList()); return normalize(dataFrame, 0, 1, Collections.<String>emptyList());
} }
@ -120,8 +119,8 @@ public class Normalization {
* @return a zero mean unit variance centered * @return a zero mean unit variance centered
* rdd * rdd
*/ */
public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List<String> skipColumns) { public static Dataset<Row> zeromeanUnitVariance(Dataset<Row> frame, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(frame.get().columns()); List<String> columnsList = DataFrames.toList(frame.columns());
columnsList.removeAll(skipColumns); columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList); String[] columnNames = DataFrames.toArray(columnsList);
//first row is std second row is mean, each column in a row is for a particular column //first row is std second row is mean, each column in a row is for a particular column
@ -133,7 +132,7 @@ public class Normalization {
if (std == 0.0) if (std == 0.0)
std = 1; //All same value -> (x-x)/1 = 0 std = 1; //All same value -> (x-x)/1 = 0
frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std))); frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std));
} }
@ -152,7 +151,7 @@ public class Normalization {
*/ */
public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data, public static JavaRDD<List<Writable>> zeromeanUnitVariance(Schema schema, JavaRDD<List<Writable>> data,
List<String> skipColumns) { List<String> skipColumns) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data); Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond(); return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond();
} }
@ -178,7 +177,7 @@ public class Normalization {
*/ */
public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema, public static JavaRDD<List<List<Writable>>> zeroMeanUnitVarianceSequence(Schema schema,
JavaRDD<List<List<Writable>>> sequence, List<String> excludeColumns) { JavaRDD<List<List<Writable>>> sequence, List<String> excludeColumns) {
DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence); Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, sequence);
if (excludeColumns == null) if (excludeColumns == null)
excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN); excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN);
else { else {
@ -196,7 +195,7 @@ public class Normalization {
* @param columns the columns to get the * @param columns the columns to get the
* @return * @return
*/ */
public static List<Row> minMaxColumns(DataRowsFacade data, List<String> columns) { public static List<Row> minMaxColumns(Dataset<Row> data, List<String> columns) {
String[] arr = new String[columns.size()]; String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++) for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i); arr[i] = columns.get(i);
@ -210,7 +209,7 @@ public class Normalization {
* @param columns the columns to get the * @param columns the columns to get the
* @return * @return
*/ */
public static List<Row> minMaxColumns(DataRowsFacade data, String... columns) { public static List<Row> minMaxColumns(Dataset<Row> data, String... columns) {
return aggregate(data, columns, new String[] {"min", "max"}); return aggregate(data, columns, new String[] {"min", "max"});
} }
@ -221,7 +220,7 @@ public class Normalization {
* @param columns the columns to get the * @param columns the columns to get the
* @return * @return
*/ */
public static List<Row> stdDevMeanColumns(DataRowsFacade data, List<String> columns) { public static List<Row> stdDevMeanColumns(Dataset<Row> data, List<String> columns) {
String[] arr = new String[columns.size()]; String[] arr = new String[columns.size()];
for (int i = 0; i < arr.length; i++) for (int i = 0; i < arr.length; i++)
arr[i] = columns.get(i); arr[i] = columns.get(i);
@ -237,7 +236,7 @@ public class Normalization {
* @param columns the columns to get the * @param columns the columns to get the
* @return * @return
*/ */
public static List<Row> stdDevMeanColumns(DataRowsFacade data, String... columns) { public static List<Row> stdDevMeanColumns(Dataset<Row> data, String... columns) {
return aggregate(data, columns, new String[] {"stddev", "mean"}); return aggregate(data, columns, new String[] {"stddev", "mean"});
} }
@ -251,7 +250,7 @@ public class Normalization {
* Each row will be a function with the desired columnar output * Each row will be a function with the desired columnar output
* in the order in which the columns were specified. * in the order in which the columns were specified.
*/ */
public static List<Row> aggregate(DataRowsFacade data, String[] columns, String[] functions) { public static List<Row> aggregate(Dataset<Row> data, String[] columns, String[] functions) {
String[] rest = new String[columns.length - 1]; String[] rest = new String[columns.length - 1];
System.arraycopy(columns, 1, rest, 0, rest.length); System.arraycopy(columns, 1, rest, 0, rest.length);
List<Row> rows = new ArrayList<>(); List<Row> rows = new ArrayList<>();
@ -262,8 +261,8 @@ public class Normalization {
} }
//compute the aggregation based on the operation //compute the aggregation based on the operation
DataRowsFacade aggregated = dataRows(data.get().agg(expressions)); Dataset<Row> aggregated = data.agg(expressions);
String[] columns2 = aggregated.get().columns(); String[] columns2 = aggregated.columns();
//strip out the op name and parentheses from the columns //strip out the op name and parentheses from the columns
Map<String, String> opReplace = new TreeMap<>(); Map<String, String> opReplace = new TreeMap<>();
for (String s : columns2) { for (String s : columns2) {
@ -278,20 +277,20 @@ public class Normalization {
//get rid of the operation name in the column //get rid of the operation name in the column
DataRowsFacade rearranged = null; Dataset<Row> rearranged = null;
for (Map.Entry<String, String> entries : opReplace.entrySet()) { for (Map.Entry<String, String> entries : opReplace.entrySet()) {
//first column //first column
if (rearranged == null) { if (rearranged == null) {
rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue())); rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue());
} }
//rearranged is just a copy of aggregated at this point //rearranged is just a copy of aggregated at this point
else else
rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue())); rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue());
} }
rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns))); rearranged = rearranged.select(DataFrames.toColumns(columns));
//op //op
rows.addAll(rearranged.get().collectAsList()); rows.addAll(rearranged.collectAsList());
} }
@ -307,8 +306,8 @@ public class Normalization {
* @param max the maximum value * @param max the maximum value
* @return the normalized dataframe per column * @return the normalized dataframe per column
*/ */
public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List<String> skipColumns) { public static Dataset<Row> normalize(Dataset<Row> dataFrame, double min, double max, List<String> skipColumns) {
List<String> columnsList = DataFrames.toList(dataFrame.get().columns()); List<String> columnsList = DataFrames.toList(dataFrame.columns());
columnsList.removeAll(skipColumns); columnsList.removeAll(skipColumns);
String[] columnNames = DataFrames.toArray(columnsList); String[] columnNames = DataFrames.toArray(columnsList);
//first row is min second row is max, each column in a row is for a particular column //first row is min second row is max, each column in a row is for a particular column
@ -321,8 +320,8 @@ public class Normalization {
if (maxSubMin == 0) if (maxSubMin == 0)
maxSubMin = 1; maxSubMin = 1;
Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min); Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min);
dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol)); dataFrame = dataFrame.withColumn(columnName, newCol);
} }
@ -340,7 +339,7 @@ public class Normalization {
*/ */
public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max, public static JavaRDD<List<Writable>> normalize(Schema schema, JavaRDD<List<Writable>> data, double min, double max,
List<String> skipColumns) { List<String> skipColumns) {
DataRowsFacade frame = DataFrames.toDataFrame(schema, data); Dataset<Row> frame = DataFrames.toDataFrame(schema, data);
return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond(); return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond();
} }
@ -387,7 +386,7 @@ public class Normalization {
excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN); excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN);
excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN); excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN);
} }
DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data); Dataset<Row> frame = DataFrames.toDataFrameSequence(schema, data);
return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond(); return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond();
} }
@ -398,7 +397,7 @@ public class Normalization {
* @param dataFrame the dataframe to scale * @param dataFrame the dataframe to scale
* @return the normalized dataframe per column * @return the normalized dataframe per column
*/ */
public static DataRowsFacade normalize(DataRowsFacade dataFrame, List<String> skipColumns) { public static Dataset<Row> normalize(Dataset<Row> dataFrame, List<String> skipColumns) {
return normalize(dataFrame, 0, 1, skipColumns); return normalize(dataFrame, 0, 1, skipColumns);
} }

View File

@ -16,9 +16,10 @@
package org.datavec.spark.transform.analysis; package org.datavec.spark.transform.analysis;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import java.util.Iterator;
import java.util.List; import java.util.List;
/** /**
@ -27,10 +28,11 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<Writable>> { public class SequenceFlatMapFunction implements FlatMapFunction<List<List<Writable>>, List<Writable>> {
public SequenceFlatMapFunction() { @Override
super(new SequenceFlatMapFunctionAdapter()); public Iterator<List<Writable>> call(List<List<Writable>> collections) throws Exception {
return collections.iterator();
} }
} }

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.analysis;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import java.util.List;
/**
* SequenceFlatMapFunction: very simple function used to flatten a sequence
* Typically used only internally for certain analysis operations
*
* @author Alex Black
*/
public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, List<Writable>> {
@Override
public Iterable<List<Writable>> call(List<List<Writable>> collections) throws Exception {
return collections;
}
}

View File

@ -16,11 +16,14 @@
package org.datavec.spark.transform.join; package org.datavec.spark.transform.join;
import org.nd4j.shade.guava.collect.Iterables;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join; import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import scala.Tuple2; import scala.Tuple2;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List; import java.util.List;
/** /**
@ -28,10 +31,89 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
public class ExecuteJoinFromCoGroupFlatMapFunction extends public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
BaseFlatMapFunctionAdaptee<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
private final Join join;
public ExecuteJoinFromCoGroupFlatMapFunction(Join join) { public ExecuteJoinFromCoGroupFlatMapFunction(Join join) {
super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join)); this.join = join;
}
@Override
public Iterator<List<Writable>> call(
Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
throws Exception {
Iterable<List<Writable>> leftList = t2._2()._1();
Iterable<List<Writable>> rightList = t2._2()._2();
List<List<Writable>> ret = new ArrayList<>();
Join.JoinType jt = join.getJoinType();
switch (jt) {
case Inner:
//Return records where key columns appear in BOTH
//So if no values from left OR right: no return values
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
break;
case LeftOuter:
//Return all records from left, even if no corresponding right value (NullWritable in that case)
for (List<Writable> jvl : leftList) {
if (Iterables.size(rightList) == 0) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
} else {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case RightOuter:
//Return all records from right, even if no corresponding left value (NullWritable in that case)
for (List<Writable> jvr : rightList) {
if (Iterables.size(leftList) == 0) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
} else {
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case FullOuter:
//Return all records, even if no corresponding left/right value (NullWritable in that case)
if (Iterables.size(leftList) == 0) {
//Only right values
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
}
} else if (Iterables.size(rightList) == 0) {
//Only left values
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
}
} else {
//Records from both left and right
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
}
return ret.iterator();
} }
} }

View File

@ -1,119 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.join;
import com.google.common.collect.Iterables;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.List;
/**
* Execute a join
*
* @author Alex Black
*/
public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements
FlatMapFunctionAdapter<Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>>, List<Writable>> {
private final Join join;
public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) {
this.join = join;
}
@Override
public Iterable<List<Writable>> call(
Tuple2<List<Writable>, Tuple2<Iterable<List<Writable>>, Iterable<List<Writable>>>> t2)
throws Exception {
Iterable<List<Writable>> leftList = t2._2()._1();
Iterable<List<Writable>> rightList = t2._2()._2();
List<List<Writable>> ret = new ArrayList<>();
Join.JoinType jt = join.getJoinType();
switch (jt) {
case Inner:
//Return records where key columns appear in BOTH
//So if no values from left OR right: no return values
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
break;
case LeftOuter:
//Return all records from left, even if no corresponding right value (NullWritable in that case)
for (List<Writable> jvl : leftList) {
if (Iterables.size(rightList) == 0) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
} else {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case RightOuter:
//Return all records from right, even if no corresponding left value (NullWritable in that case)
for (List<Writable> jvr : rightList) {
if (Iterables.size(leftList) == 0) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
} else {
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
case FullOuter:
//Return all records, even if no corresponding left/right value (NullWritable in that case)
if (Iterables.size(leftList) == 0) {
//Only right values
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(null, jvr);
ret.add(joined);
}
} else if (Iterables.size(rightList) == 0) {
//Only left values
for (List<Writable> jvl : leftList) {
List<Writable> joined = join.joinExamples(jvl, null);
ret.add(joined);
}
} else {
//Records from both left and right
for (List<Writable> jvl : leftList) {
for (List<Writable> jvr : rightList) {
List<Writable> joined = join.joinExamples(jvl, jvr);
ret.add(joined);
}
}
}
break;
}
return ret;
}
}

View File

@ -16,10 +16,12 @@
package org.datavec.spark.transform.join; package org.datavec.spark.transform.join;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.join.Join; import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import java.util.Collections;
import java.util.Iterator;
import java.util.List; import java.util.List;
/** /**
@ -29,10 +31,43 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee<JoinedValue, List<Writable>> { public class FilterAndFlattenJoinedValues implements FlatMapFunction<JoinedValue, List<Writable>> {
private final Join.JoinType joinType;
public FilterAndFlattenJoinedValues(Join.JoinType joinType) { public FilterAndFlattenJoinedValues(Join.JoinType joinType) {
super(new FilterAndFlattenJoinedValuesAdapter(joinType)); this.joinType = joinType;
}
@Override
public Iterator<List<Writable>> call(JoinedValue joinedValue) throws Exception {
boolean keep;
switch (joinType) {
case Inner:
//Only keep joined values where we have both left and right
keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
break;
case LeftOuter:
//Keep all values where left is not missing/null
keep = joinedValue.isHaveLeft();
break;
case RightOuter:
//Keep all values where right is not missing/null
keep = joinedValue.isHaveRight();
break;
case FullOuter:
//Keep all values
keep = true;
break;
default:
throw new RuntimeException("Unknown/not implemented join type: " + joinType);
}
if (keep) {
return Collections.singletonList(joinedValue.getValues()).iterator();
} else {
return Collections.emptyIterator();
}
} }
} }

View File

@ -1,71 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.join;
import org.datavec.api.transform.join.Join;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import java.util.Collections;
import java.util.List;
/**
* Doing two things here:
* (a) filter out any unnecessary values, and
* (b) extract the List<Writable> values from the JoinedValue
*
* @author Alex Black
*/
public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter<JoinedValue, List<Writable>> {
private final Join.JoinType joinType;
public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) {
this.joinType = joinType;
}
@Override
public Iterable<List<Writable>> call(JoinedValue joinedValue) throws Exception {
boolean keep;
switch (joinType) {
case Inner:
//Only keep joined values where we have both left and right
keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight();
break;
case LeftOuter:
//Keep all values where left is not missing/null
keep = joinedValue.isHaveLeft();
break;
case RightOuter:
//Keep all values where right is not missing/null
keep = joinedValue.isHaveRight();
break;
case FullOuter:
//Keep all values
keep = true;
break;
default:
throw new RuntimeException("Unknown/not implemented join type: " + joinType);
}
if (keep) {
return Collections.singletonList(joinedValue.getValues());
} else {
return Collections.emptyList();
}
}
}

View File

@ -16,21 +16,69 @@
package org.datavec.spark.transform.sparkfunction; package org.datavec.spark.transform.sparkfunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.transform.DataFrames;
import java.util.List; import java.util.*;
/** /**
* Convert a record to a row * Convert a record to a row
* @author Adam Gibson * @author Adam Gibson
*/ */
public class SequenceToRows extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, Row> { public class SequenceToRows implements FlatMapFunction<List<List<Writable>>, Row> {
private Schema schema;
private StructType structType;
public SequenceToRows(Schema schema) { public SequenceToRows(Schema schema) {
super(new SequenceToRowsAdapter(schema)); this.schema = schema;
structType = DataFrames.fromSchemaSequence(schema);
} }
@Override
public Iterator<Row> call(List<List<Writable>> sequence) throws Exception {
if (sequence.size() == 0)
return Collections.emptyIterator();
String sequenceUUID = UUID.randomUUID().toString();
List<Row> out = new ArrayList<>(sequence.size());
int stepCount = 0;
for (List<Writable> step : sequence) {
Object[] values = new Object[step.size() + 2];
values[0] = sequenceUUID;
values[1] = stepCount++;
for (int i = 0; i < step.size(); i++) {
switch (schema.getColumnTypes().get(i)) {
case Double:
values[i + 2] = step.get(i).toDouble();
break;
case Integer:
values[i + 2] = step.get(i).toInt();
break;
case Long:
values[i + 2] = step.get(i).toLong();
break;
case Float:
values[i + 2] = step.get(i).toFloat();
break;
default:
throw new IllegalStateException(
"This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
}
}
Row row = new GenericRowWithSchema(values, structType);
out.add(row);
}
return out.iterator();
}
} }

View File

@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.sparkfunction;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import org.datavec.spark.transform.DataFrames;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
/**
* Convert a record to a row
* @author Adam Gibson
*/
public class SequenceToRowsAdapter implements FlatMapFunctionAdapter<List<List<Writable>>, Row> {
private Schema schema;
private StructType structType;
public SequenceToRowsAdapter(Schema schema) {
this.schema = schema;
structType = DataFrames.fromSchemaSequence(schema);
}
@Override
public Iterable<Row> call(List<List<Writable>> sequence) throws Exception {
if (sequence.size() == 0)
return Collections.emptyList();
String sequenceUUID = UUID.randomUUID().toString();
List<Row> out = new ArrayList<>(sequence.size());
int stepCount = 0;
for (List<Writable> step : sequence) {
Object[] values = new Object[step.size() + 2];
values[0] = sequenceUUID;
values[1] = stepCount++;
for (int i = 0; i < step.size(); i++) {
switch (schema.getColumnTypes().get(i)) {
case Double:
values[i + 2] = step.get(i).toDouble();
break;
case Integer:
values[i + 2] = step.get(i).toInt();
break;
case Long:
values[i + 2] = step.get(i).toLong();
break;
case Float:
values[i + 2] = step.get(i).toFloat();
break;
default:
throw new IllegalStateException(
"This api should not be used with strings , binary data or ndarrays. This is only for columnar data");
}
}
Row row = new GenericRowWithSchema(values, structType);
out.add(row);
}
return out;
}
}

View File

@ -16,19 +16,27 @@
package org.datavec.spark.transform.transform; package org.datavec.spark.transform.transform;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.sequence.SequenceSplit; import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import java.util.Iterator;
import java.util.List; import java.util.List;
/** /**
* Created by Alex on 17/03/2016. * Created by Alex on 17/03/2016.
*/ */
public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee<List<List<Writable>>, List<List<Writable>>> { public class SequenceSplitFunction implements FlatMapFunction<List<List<Writable>>, List<List<Writable>>> {
private final SequenceSplit split;
public SequenceSplitFunction(SequenceSplit split) { public SequenceSplitFunction(SequenceSplit split) {
super(new SequenceSplitFunctionAdapter(split)); this.split = split;
}
@Override
public Iterator<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
return split.split(collections).iterator();
} }
} }

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.transform;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import java.util.List;
/**
* Created by Alex on 17/03/2016.
*/
public class SequenceSplitFunctionAdapter
implements FlatMapFunctionAdapter<List<List<Writable>>, List<List<Writable>>> {
private final SequenceSplit split;
public SequenceSplitFunctionAdapter(SequenceSplit split) {
this.split = split;
}
@Override
public Iterable<List<List<Writable>>> call(List<List<Writable>> collections) throws Exception {
return split.split(collections);
}
}

View File

@ -16,19 +16,32 @@
package org.datavec.spark.transform.transform; package org.datavec.spark.transform.transform;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee;
import java.util.Collections;
import java.util.Iterator;
import java.util.List; import java.util.List;
/** /**
* Spark function for executing a transform process * Spark function for executing a transform process
*/ */
public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee<List<Writable>, List<Writable>> { public class SparkTransformProcessFunction implements FlatMapFunction<List<Writable>, List<Writable>> {
private final TransformProcess transformProcess;
public SparkTransformProcessFunction(TransformProcess transformProcess) { public SparkTransformProcessFunction(TransformProcess transformProcess) {
super(new SparkTransformProcessFunctionAdapter(transformProcess)); this.transformProcess = transformProcess;
}
@Override
public Iterator<List<Writable>> call(List<Writable> v1) throws Exception {
List<Writable> newList = transformProcess.execute(v1);
if (newList == null)
return Collections.emptyIterator(); //Example was filtered out
else
return Collections.singletonList(newList).iterator();
} }
} }

View File

@ -1,45 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import java.util.Collections;
import java.util.List;
/**
* Spark function for executing a transform process
*/
public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter<List<Writable>, List<Writable>> {
private final TransformProcess transformProcess;
public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) {
this.transformProcess = transformProcess;
}
@Override
public Iterable<List<Writable>> call(List<Writable> v1) throws Exception {
List<Writable> newList = transformProcess.execute(v1);
if (newList == null)
return Collections.emptyList(); //Example was filtered out
else
return Collections.singletonList(newList);
}
}

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
/**
* FlatMapFunction adapter to
* hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to FlatMapFunction
*
*/
public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {
protected final FlatMapFunctionAdapter<K, V> adapter;
public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
this.adapter = adapter;
}
@Override
public Iterable<V> call(K k) throws Exception {
return adapter.call(k);
}
}

View File

@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import org.apache.spark.sql.DataFrame;
/**
* Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to DataFrame / Dataset
*
*/
public class DataRowsFacade {
private final DataFrame df;
private DataRowsFacade(DataFrame df) {
this.df = df;
}
public static DataRowsFacade dataRows(DataFrame df) {
return new DataRowsFacade(df);
}
public DataFrame get() {
return df;
}
}

View File

@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.spark.functions.FlatMapFunctionAdapter;
import java.util.Iterator;
/**
* FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to FlatMapFunction
*
*/
public class BaseFlatMapFunctionAdaptee<K, V> implements FlatMapFunction<K, V> {
protected final FlatMapFunctionAdapter<K, V> adapter;
public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter<K, V> adapter) {
this.adapter = adapter;
}
@Override
public Iterator<V> call(K k) throws Exception {
return adapter.call(k).iterator();
}
}

View File

@ -1,43 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.transform;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
/**
* Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x
*
* This class should be used instead of direct referral to DataFrame / Dataset
*
*/
public class DataRowsFacade {
private final Dataset<Row> df;
private DataRowsFacade(Dataset<Row> df) {
this.df = df;
}
public static DataRowsFacade dataRows(Dataset<Row> df) {
return new DataRowsFacade(df);
}
public Dataset<Row> get() {
return df;
}
}

View File

@ -16,7 +16,7 @@
package org.datavec.spark.storage; package org.datavec.spark.storage;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;

View File

@ -19,6 +19,8 @@ package org.datavec.spark.transform;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Column; import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++) for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i)); builder.addColumnDouble(String.valueOf(i));
Schema schema = builder.build(); Schema schema = builder.build();
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records)); Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records));
dataFrame.get().show(); dataFrame.show();
dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show(); dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show();
// System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames())); // System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames()));
// System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames())); // System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames()));
@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.get().show(); dataFrame.show();
Column mean = DataFrames.mean(dataFrame, "0"); Column mean = DataFrames.mean(dataFrame, "0");
Column std = DataFrames.std(dataFrame, "0"); Column std = DataFrames.std(dataFrame, "0");
dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show(); dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show();
dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show(); dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show();
/* DataFrame desc = dataFrame.describe(dataFrame.columns()); /* DataFrame desc = dataFrame.describe(dataFrame.columns());
dataFrame.show(); dataFrame.show();

View File

@ -17,6 +17,7 @@
package org.datavec.spark.transform; package org.datavec.spark.transform;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row; import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest {
for (int i = 0; i < numColumns; i++) for (int i = 0; i < numColumns; i++)
builder.addColumnDouble(String.valueOf(i)); builder.addColumnDouble(String.valueOf(i));
Nd4j.getRandom().setSeed(12345);
INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns);
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
List<Writable> record = new ArrayList<>(numColumns); List<Writable> record = new ArrayList<>(numColumns);
data.add(record); data.add(record);
for (int j = 0; j < numColumns; j++) { for (int j = 0; j < numColumns; j++) {
record.add(new DoubleWritable(1.0)); record.add(new DoubleWritable(arr.getDouble(i, j)));
} }
} }
INDArray arr = RecordConverter.toMatrix(data);
Schema schema = builder.build(); Schema schema = builder.build();
JavaRDD<List<Writable>> rdd = sc.parallelize(data); JavaRDD<List<Writable>> rdd = sc.parallelize(data);
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
//assert equivalent to the ndarray pre processing //assert equivalent to the ndarray pre processing
NormalizerStandardize standardScaler = new NormalizerStandardize();
standardScaler.fit(new DataSet(arr.dup(), arr.dup()));
INDArray standardScalered = arr.dup();
standardScaler.transform(new DataSet(standardScalered, standardScalered));
DataNormalization zeroToOne = new NormalizerMinMaxScaler(); DataNormalization zeroToOne = new NormalizerMinMaxScaler();
zeroToOne.fit(new DataSet(arr.dup(), arr.dup())); zeroToOne.fit(new DataSet(arr.dup(), arr.dup()));
INDArray zeroToOnes = arr.dup(); INDArray zeroToOnes = arr.dup();
zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes));
List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns()); List<Row> rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns());
INDArray assertion = DataFrames.toMatrix(rows); INDArray assertion = DataFrames.toMatrix(rows);
//compare standard deviation INDArray expStd = arr.std(true, true, 0);
assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1)); INDArray std = assertion.getRow(0, true);
assertTrue(expStd.equalsWithEps(std, 1e-3));
//compare mean //compare mean
assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1)); INDArray expMean = arr.mean(true, 0);
assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3));
} }
@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest {
assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());
DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); Dataset<Row> dataFrame = DataFrames.toDataFrame(schema, rdd);
dataFrame.get().show(); dataFrame.show();
Normalization.zeromeanUnitVariance(dataFrame).get().show(); Normalization.zeromeanUnitVariance(dataFrame).show();
Normalization.normalize(dataFrame).get().show(); Normalization.normalize(dataFrame).show();
//assert equivalent to the ndarray pre processing //assert equivalent to the ndarray pre processing
NormalizerStandardize standardScaler = new NormalizerStandardize(); NormalizerStandardize standardScaler = new NormalizerStandardize();

View File

@ -330,53 +330,6 @@
</plugins> </plugins>
</build> </build>
<profiles>
<!-- For running unit tests with nd4j-native: "mvn clean test -P test-nd4j-native"
Note that this excludes DL4J-cuda -->
<profile>
<id>test-nd4j-native</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${dl4j-test-resources.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
<profile>
<id>test-nd4j-cuda-10.1</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${dl4j-test-resources.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<!-- Default to ALL modules here, unlike nd4j-native -->
</profile>
</profiles>
<reporting> <reporting>
<plugins> <plugins>
<plugin> <plugin>
@ -391,4 +344,42 @@
</plugin> </plugin>
</plugins> </plugins>
</reporting> </reporting>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project> </project>

View File

@ -52,18 +52,6 @@ public class DL4JSystemProperties {
*/ */
public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl"; public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl";
/**
* Applicability: deeplearning4j-nn<br>
* Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an
* alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in
* comma-separated format.<br>
* This is required ONLY when ALL of the following conditions are met:<br>
* 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND<br>
* 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J), AND<br>
* 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}
*/
public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses";
/** /**
* Applicability: deeplearning4j-nn<br> * Applicability: deeplearning4j-nn<br>
* Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled * Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled

View File

@ -29,23 +29,11 @@
<artifactId>nd4j-api</artifactId> <artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-common</artifactId> <artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>
@ -82,7 +70,6 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.deeplearning4j</groupId> <groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId> <artifactId>deeplearning4j-nn</artifactId>
@ -109,18 +96,6 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId> <artifactId>nd4j-api</artifactId>
@ -132,8 +107,6 @@
<version>${commonslang.version}</version> <version>${commonslang.version}</version>
</dependency> </dependency>
<!-- ND4J Shaded Jackson Dependency --> <!-- ND4J Shaded Jackson Dependency -->
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
@ -141,7 +114,6 @@
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.projectlombok</groupId> <groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId> <artifactId>lombok</artifactId>
@ -180,9 +152,26 @@
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.1</id> <id>test-nd4j-cuda-10.1</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import com.google.common.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;

View File

@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.fail; import java.util.Map;
import static org.junit.Assert.*;
/** /**
* A set of tests to ensure that useful exceptions are thrown on invalid input * A set of tests to ensure that useful exceptions are thrown on invalid input
@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest {
//Idea: Using rnnTimeStep with a different number of examples between calls //Idea: Using rnnTimeStep with a different number of examples between calls
//(i.e., not calling reset between time steps) //(i.e., not calling reset between time steps)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() for(String layerType : new String[]{"simple", "lstm", "graves"}) {
.layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build())
.layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); Layer l;
net.init(); switch (layerType){
case "simple":
l = new SimpleRnn.Builder().nIn(5).nOut(5).build();
break;
case "lstm":
l = new LSTM.Builder().nIn(5).nOut(5).build();
break;
case "graves":
l = new GravesLSTM.Builder().nIn(5).nOut(5).build();
break;
default:
throw new RuntimeException();
}
net.rnnTimeStep(Nd4j.create(3, 5, 10)); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
.layer(l)
.layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build();
try { MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.rnnTimeStep(Nd4j.create(5, 5, 10)); net.init();
fail("Expected DL4JException");
} catch (DL4JException e) { net.rnnTimeStep(Nd4j.create(3, 5, 10));
System.out.println("testInvalidRnnTimeStep(): " + e.getMessage());
} catch (Exception e) { Map<String, INDArray> m = net.rnnGetPreviousState(0);
e.printStackTrace(); assertNotNull(m);
fail("Expected DL4JException"); assertFalse(m.isEmpty());
try {
net.rnnTimeStep(Nd4j.create(5, 5, 10));
fail("Expected Exception - " + layerType);
} catch (Exception e) {
// e.printStackTrace();
String msg = e.getMessage();
assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch"));
}
} }
} }
} }

Some files were not shown because too many files have changed in this diff Show More