diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 5d06744d3..064dd3ecd 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -72,6 +72,11 @@ test + + joda-time + joda-time + ${jodatime.version} + @@ -80,4 +85,13 @@ ${nd4j.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java index 9e50065e6..a9c0933c4 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java @@ -16,7 +16,7 @@ package org.deeplearning4j.arbiter.optimize.distribution; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.exception.NumberIsTooLargeException; diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java index fa503ef6d..0e04a130b 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java @@ -16,7 +16,7 @@ package org.deeplearning4j.arbiter.optimize.runner; -import com.google.common.util.concurrent.ListenableFuture; +import org.nd4j.shade.guava.util.concurrent.ListenableFuture; import lombok.AllArgsConstructor; import lombok.Data; import lombok.extern.slf4j.Slf4j; diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java index a3992b09a..6982090f1 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java @@ -16,9 +16,9 @@ package org.deeplearning4j.arbiter.optimize.runner; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.MoreExecutors; +import org.nd4j.shade.guava.util.concurrent.ListenableFuture; +import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService; +import org.nd4j.shade.guava.util.concurrent.MoreExecutors; import lombok.Setter; import org.deeplearning4j.arbiter.optimize.api.*; import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java index 9e30a06f6..8cfb07723 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java @@ -43,13 +43,15 @@ public class JsonMapper { 
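The hunk below, and the matching changes to YamlMapper, TestJson, the Arbiter UI JsonMapper and JsonMappers.configureMapper later in the patch, add CREATOR visibility alongside the existing FIELD visibility so that @JsonProperty-annotated constructors are still discovered once all other auto-detection is switched off. A minimal, self-contained sketch of that behaviour follows; the Point class and main method are illustrative only and are not part of this patch.

    import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
    import org.nd4j.shade.jackson.annotation.JsonProperty;
    import org.nd4j.shade.jackson.annotation.PropertyAccessor;
    import org.nd4j.shade.jackson.databind.ObjectMapper;

    public class CreatorVisibilitySketch {
        // Illustrative immutable value class: no setters and no no-arg constructor,
        // so Jackson has to use the @JsonProperty-annotated constructor.
        public static class Point {
            private final int x;
            private final int y;
            public Point(@JsonProperty("x") int x, @JsonProperty("y") int y) {
                this.x = x;
                this.y = y;
            }
        }

        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();
            // Same visibility settings this patch applies:
            mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
            mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
            // Without the next line the constructor's @JsonProperty annotations are not seen
            // (per the comment added in JsonMappers.configureMapper) and deserialization fails.
            mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);

            Point p = mapper.readValue("{\"x\":1,\"y\":2}", Point.class);
            System.out.println(p.x + "," + p.y);
        }
    }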
mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); yamlMapper = new ObjectMapper(new YAMLFactory()); - mapper.registerModule(new JodaModule()); - mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - mapper.enable(SerializationFeature.INDENT_OUTPUT); - mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); - mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + yamlMapper.registerModule(new JodaModule()); + yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + yamlMapper.enable(SerializationFeature.INDENT_OUTPUT); + yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); + yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); } private JsonMapper() {} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java index f10c593e0..5b35220e9 100644 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java +++ b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java @@ -39,6 +39,7 @@ public class YamlMapper { mapper.enable(SerializationFeature.INDENT_OUTPUT); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); } diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java index 3572b187b..6f1b336bb 100644 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java @@ -59,6 +59,7 @@ public class TestJson { om.enable(SerializationFeature.INDENT_OUTPUT); om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); return om; } diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml index 91c5a11ed..85afe7a6b 100644 --- a/arbiter/arbiter-deeplearning4j/pom.xml +++ b/arbiter/arbiter-deeplearning4j/pom.xml @@ -38,13 +38,6 @@ ${dl4j.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - junit junit @@ -64,6 +57,20 @@ jackson ${nd4j.version} + + + com.google.code.gson + gson + ${gson.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java index 77c31707a..0a5e33d27 100644 --- 
a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java +++ b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java @@ -16,7 +16,7 @@ package org.deeplearning4j.arbiter.layers; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.AccessLevel; import lombok.Data; import lombok.EqualsAndHashCode; diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml index caef6b6bd..5d14fa6a0 100644 --- a/arbiter/arbiter-server/pom.xml +++ b/arbiter/arbiter-server/pom.xml @@ -49,11 +49,14 @@ ${junit.version} test - - org.nd4j - nd4j-native - ${nd4j.version} - test - + + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + + diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml index 7d83c737c..56d1013bf 100644 --- a/arbiter/arbiter-ui/pom.xml +++ b/arbiter/arbiter-ui/pom.xml @@ -97,15 +97,16 @@ + + + test-nd4j-native + + + test-nd4j-cuda-10.1 + - - com.google.guava - guava - ${guava.version} - - org.deeplearning4j arbiter-core @@ -124,13 +125,6 @@ ${project.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - junit junit @@ -139,9 +133,6 @@ - - - @@ -222,5 +213,4 @@ - diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java index 491756ae7..0ac5ad383 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java @@ -18,6 +18,7 @@ package org.deeplearning4j.arbiter.ui.misc; import org.nd4j.shade.jackson.annotation.JsonAutoDetect; +import org.nd4j.shade.jackson.annotation.PropertyAccessor; import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; @@ -45,12 +46,9 @@ public class JsonMapper { mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); mapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true); mapper.enable(SerializationFeature.INDENT_OUTPUT); - - mapper.setVisibilityChecker(mapper.getSerializationConfig().getDefaultVisibilityChecker() - .withFieldVisibility(JsonAutoDetect.Visibility.ANY) - .withGetterVisibility(JsonAutoDetect.Visibility.NONE) - .withSetterVisibility(JsonAutoDetect.Visibility.NONE) - .withCreatorVisibility(JsonAutoDetect.Visibility.NONE)); + mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); + mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); return mapper; } diff --git a/arbiter/pom.xml b/arbiter/pom.xml index f4380c58c..5f660c646 100644 --- a/arbiter/pom.xml +++ b/arbiter/pom.xml @@ -136,6 +136,31 @@ + + org.apache.maven.plugins + maven-enforcer-plugin + ${maven-enforcer-plugin.version} + + + test + enforce-test-resources + + enforce + + + ${skipTestResourceEnforcement} + + + test-nd4j-native,test-nd4j-cuda-10.1 + false + + + true + + + + + maven-javadoc-plugin ${maven-javadoc-plugin.version} @@ -287,4 +312,42 @@ + + + + test-nd4j-native + + + org.nd4j + nd4j-native + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${nd4j.version} + test + + + org.deeplearning4j + 
dl4j-test-resources + ${nd4j.version} + test + + + + diff --git a/change-scala-versions.sh b/change-scala-versions.sh index 5fb40c5cb..8968abbf3 100755 --- a/change-scala-versions.sh +++ b/change-scala-versions.sh @@ -20,9 +20,9 @@ set -e -VALID_VERSIONS=( 2.10 2.11 ) -SCALA_210_VERSION=$(grep -F -m 1 'scala210.version' pom.xml); SCALA_210_VERSION="${SCALA_210_VERSION#*>}"; SCALA_210_VERSION="${SCALA_210_VERSION%<*}"; +VALID_VERSIONS=( 2.11 2.12 ) SCALA_211_VERSION=$(grep -F -m 1 'scala211.version' pom.xml); SCALA_211_VERSION="${SCALA_211_VERSION#*>}"; SCALA_211_VERSION="${SCALA_211_VERSION%<*}"; +SCALA_212_VERSION=$(grep -F -m 1 'scala212.version' pom.xml); SCALA_212_VERSION="${SCALA_212_VERSION#*>}"; SCALA_212_VERSION="${SCALA_212_VERSION%<*}"; usage() { echo "Usage: $(basename $0) [-h|--help] @@ -45,19 +45,18 @@ check_scala_version() { exit 1 } - check_scala_version "$TO_VERSION" if [ $TO_VERSION = "2.11" ]; then - FROM_BINARY="_2\.10" + FROM_BINARY="_2\.12" TO_BINARY="_2\.11" - FROM_VERSION=$SCALA_210_VERSION + FROM_VERSION=$SCALA_212_VERSION TO_VERSION=$SCALA_211_VERSION else FROM_BINARY="_2\.11" - TO_BINARY="_2\.10" + TO_BINARY="_2\.12" FROM_VERSION=$SCALA_211_VERSION - TO_VERSION=$SCALA_210_VERSION + TO_VERSION=$SCALA_212_VERSION fi sed_i() { @@ -70,35 +69,24 @@ echo "Updating Scala versions in pom.xml files to Scala $1, from $FROM_VERSION t BASEDIR=$(dirname $0) -#Artifact ids, ending with "_2.10" or "_2.11". Spark, spark-mllib, kafka, etc. +#Artifact ids, ending with "_2.11" or "_2.12". Spark, spark-mllib, kafka, etc. find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(artifactId>.*\)'$FROM_BINARY'<\/artifactId>/\1'$TO_BINARY'<\/artifactId>/g' {}" \; -#Scala versions, like 2.10 +#Scala versions, like 2.11 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(scala.version>\)'$FROM_VERSION'<\/scala.version>/\1'$TO_VERSION'<\/scala.version>/g' {}" \; -#Scala binary versions, like 2.10 +#Scala binary versions, like 2.11 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(scala.binary.version>\)'${FROM_BINARY#_}'<\/scala.binary.version>/\1'${TO_BINARY#_}'<\/scala.binary.version>/g' {}" \; -#Scala versions, like scala-library 2.10.6 +#Scala versions, like scala-library 2.11.12 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(version>\)'$FROM_VERSION'<\/version>/\1'$TO_VERSION'<\/version>/g' {}" \; -#Scala maven plugin, 2.10 +#Scala maven plugin, 2.11 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \; - -#Edge case for Korean NLP artifact not following conventions: https://github.com/deeplearning4j/deeplearning4j/issues/6306 -#https://github.com/deeplearning4j/deeplearning4j/issues/6306 -if [[ $TO_VERSION == 2.11* ]]; then - sed_i 's/korean-text-scala-2.10<\/artifactId>/korean-text<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml - sed_i 's/4.2.0<\/version>/4.4<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml -else - sed_i 's/korean-text<\/artifactId>/korean-text-scala-2.10<\/artifactId>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml - sed_i 's/4.4<\/version>/4.2.0<\/version>/g' deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml -fi - echo "Done updating Scala versions."; diff --git a/change-spark-versions.sh 
b/change-spark-versions.sh deleted file mode 100755 index 06a9b4d55..000000000 --- a/change-spark-versions.sh +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env bash - -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -# This shell script is adapted from Apache Flink (in turn, adapted from Apache Spark) some modifications. - -set -e - -VALID_VERSIONS=( 1 2 ) -SPARK_2_VERSION="2\.1\.0" -SPARK_1_VERSION="1\.6\.3" - -usage() { - echo "Usage: $(basename $0) [-h|--help] -where : - -h| --help Display this help text - valid spark version values : ${VALID_VERSIONS[*]} -" 1>&2 - exit 1 -} - -if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then - usage -fi - -TO_VERSION=$1 - -check_spark_version() { - for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done - echo "Invalid Spark version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 - exit 1 -} - - -check_spark_version "$TO_VERSION" - -if [ $TO_VERSION = "2" ]; then - FROM_BINARY="1" - TO_BINARY="2" - FROM_VERSION=$SPARK_1_VERSION - TO_VERSION=$SPARK_2_VERSION -else - FROM_BINARY="2" - TO_BINARY="1" - FROM_VERSION=$SPARK_2_VERSION - TO_VERSION=$SPARK_1_VERSION -fi - -sed_i() { - sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" -} - -export -f sed_i - -echo "Updating Spark versions in pom.xml files to Spark $1"; - -BASEDIR=$(dirname $0) - -# 1 -find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ - -exec bash -c "sed_i 's/\(spark.major.version>\)'$FROM_BINARY'<\/spark.major.version>/\1'$TO_BINARY'<\/spark.major.version>/g' {}" \; - -# 1.6.3 -find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ - -exec bash -c "sed_i 's/\(spark.version>\)'$FROM_VERSION'<\/spark.version>/\1'$TO_VERSION'<\/spark.version>/g' {}" \; - -#Spark versions, like xxx_spark_2xxx OR xxx_spark_2xxx -find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ - -exec bash -c "sed_i 's/\(version>.*_spark_\)'$FROM_BINARY'\(.*\)version>/\1'$TO_BINARY'\2version>/g' {}" \; - -echo "Done updating Spark versions."; diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml index 73ecef1e5..022f2e38b 100644 --- a/datavec/datavec-api/pom.xml +++ b/datavec/datavec-api/pom.xml @@ -26,7 +26,6 @@ datavec-api - org.apache.commons @@ -98,13 +97,6 @@ ${stream.analytics.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - net.sf.opencsv @@ -125,7 +117,6 @@ - test-nd4j-native diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java index 44098e6cc..32f8f7cc5 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Transform.java @@ -16,7 +16,6 @@ package org.datavec.api.transform; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import 
org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -27,8 +26,7 @@ import java.util.List; /**A Transform converts an example to another example, or a sequence to another sequence */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.TransformHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Transform extends Serializable, ColumnOp { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java index 24a029179..7c57f4daa 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java @@ -67,6 +67,7 @@ import org.joda.time.DateTimeZone; import org.nd4j.linalg.primitives.Pair; import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import java.io.IOException; import java.io.Serializable; @@ -417,6 +418,16 @@ public class TransformProcess implements Serializable { public static TransformProcess fromJson(String json) { try { return JsonMappers.getMapper().readValue(json, TransformProcess.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + try{ + return JsonMappers.getLegacyMapper().readValue(json, TransformProcess.class); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (IOException e) { //TODO proper exception message throw new RuntimeException(e); diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java index 6b069a9ec..467db70f0 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/DataAnalysis.java @@ -23,12 +23,14 @@ import org.datavec.api.transform.analysis.columns.ColumnAnalysis; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; +import org.datavec.api.transform.serde.JsonMappers; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.databind.node.ArrayNode; import java.io.IOException; @@ -116,6 +118,16 @@ public class DataAnalysis implements Serializable { public static DataAnalysis fromJson(String json) { try{ return new JsonSerializer().getObjectMapper().readValue(json, DataAnalysis.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + try{ + //JSON may be legacy (1.0.0-alpha or 
earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, DataAnalysis.class); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (Exception e){ //Legacy format ObjectMapper om = new JsonSerializer().getObjectMapper(); diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java index ecc333d2d..6156ead40 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/SequenceDataAnalysis.java @@ -21,9 +21,10 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.analysis.columns.ColumnAnalysis; import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis; import org.datavec.api.transform.schema.Schema; +import org.datavec.api.transform.serde.JsonMappers; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import java.io.IOException; import java.util.List; @@ -50,6 +51,16 @@ public class SequenceDataAnalysis extends DataAnalysis { public static SequenceDataAnalysis fromJson(String json){ try{ return new JsonSerializer().getObjectMapper().readValue(json, SequenceDataAnalysis.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + try{ + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, SequenceDataAnalysis.class); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (IOException e){ throw new RuntimeException(e); } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java index dd43315d2..c86584ede 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/ColumnAnalysis.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.analysis.columns; import org.datavec.api.transform.ColumnType; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -27,8 +26,7 @@ import java.io.Serializable; * Interface for column analysis */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.ColumnAnalysisHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ColumnAnalysis extends Serializable { long getCountTotal(); diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java index 83a96bfcf..6bd5b98ac 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java +++ 
b/datavec/datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java @@ -18,7 +18,6 @@ package org.datavec.api.transform.condition; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -35,8 +34,7 @@ import java.util.List; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.ConditionHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Condition extends Serializable, ColumnOp { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java index 5aa672f76..16870e9f9 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java @@ -18,7 +18,6 @@ package org.datavec.api.transform.filter; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -33,8 +32,7 @@ import java.util.List; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.FilterHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Filter extends Serializable, ColumnOp { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java index cde86cf90..6889fbf31 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.metadata; import org.datavec.api.transform.ColumnType; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -32,8 +31,7 @@ import java.io.Serializable; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.ColumnMetaDataHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ColumnMetaData extends Serializable, Cloneable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java index 3341e0b6d..8ef67bacb 100644 --- 
a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java @@ -23,8 +23,8 @@ import org.datavec.api.writable.Writable; import java.util.List; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; +import static org.nd4j.shade.guava.base.Preconditions.checkArgument; +import static org.nd4j.shade.guava.base.Preconditions.checkNotNull; /** * A variant of {@link DispatchOp} that for each operation, tests the input list of {@Writable} elements for a {@link Condition}, diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java index 5bcf81f7a..1e5177c68 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java @@ -23,7 +23,6 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.comparator.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonInclude; @@ -50,8 +49,7 @@ import java.util.List; @EqualsAndHashCode(exclude = {"inputSchema"}) @JsonIgnoreProperties({"inputSchema"}) @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.CalculateSortedRankHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public class CalculateSortedRank implements Serializable, ColumnOp { private final String newColumnName; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java index a9e167943..1c16ebcce 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java @@ -21,7 +21,6 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.serde.JsonMappers; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.*; import org.joda.time.DateTimeZone; import org.nd4j.shade.jackson.annotation.*; @@ -29,9 +28,11 @@ import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.datatype.joda.JodaModule; +import java.io.IOException; import java.io.Serializable; import java.util.*; @@ -48,8 +49,7 @@ import java.util.*; */ @JsonIgnoreProperties({"columnNames", "columnNamesIndex"}) @EqualsAndHashCode -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, 
include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.SchemaHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @Data public class Schema implements Serializable { @@ -358,6 +358,16 @@ public class Schema implements Serializable { public static Schema fromJson(String json) { try{ return JsonMappers.getMapper().readValue(json, Schema.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + try{ + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, Schema.class); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (Exception e){ //TODO better exceptions throw new RuntimeException(e); @@ -379,21 +389,6 @@ public class Schema implements Serializable { } } - private static Schema fromJacksonString(String str, JsonFactory factory) { - ObjectMapper om = new ObjectMapper(factory); - om.registerModule(new JodaModule()); - om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - om.enable(SerializationFeature.INDENT_OUTPUT); - om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); - om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); - try { - return om.readValue(str, Schema.class); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - public static class Builder { List columnMetaData = new ArrayList<>(); diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java index e4a09f6e9..a5677d1f5 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.sequence; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -30,8 +29,7 @@ import java.util.List; * Compare the time steps of a sequence */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.SequenceComparatorHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface SequenceComparator extends Comparator>, Serializable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java index bef0b4ccc..3471dbaa3 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.sequence; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; 
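With the @JsonTypeInfo defaultImpl fallbacks to LegacyMappingHelper removed, the fromJson changes above (TransformProcess, DataAnalysis, SequenceDataAnalysis and Schema) all follow the same pattern: try the current "@class" format first and, only when Jackson raises an InvalidTypeIdException that mentions "@class", retry with the legacy mapper for 1.0.0-alpha and earlier JSON. A distilled, standalone version of that pattern is sketched below; the helper class and method names are invented for illustration and do not appear in the patch.

    import org.nd4j.shade.jackson.databind.ObjectMapper;
    import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;

    import java.io.IOException;

    public class LegacyFallbackSketch {
        // Same try-current-then-legacy logic the patch adds to the various fromJson methods.
        public static <T> T readWithFallback(ObjectMapper current, ObjectMapper legacy,
                                             String json, Class<T> type) {
            try {
                return current.readValue(json, type);
            } catch (InvalidTypeIdException e) {
                // An unknown-type-id failure on "@class" usually means old-format JSON
                if (e.getMessage() != null && e.getMessage().contains("@class")) {
                    try {
                        return legacy.readValue(json, type);
                    } catch (IOException e2) {
                        throw new RuntimeException(e2);
                    }
                }
                throw new RuntimeException(e);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }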
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -32,8 +31,7 @@ import java.util.List; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.SequenceSplitHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface SequenceSplit extends Serializable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java index e25f85253..1456af8d7 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.sequence.window; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -36,8 +35,7 @@ import java.util.List; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.WindowFunctionHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface WindowFunction extends Serializable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java index 652556ce4..bfa114697 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java @@ -16,44 +16,17 @@ package org.datavec.api.transform.serde; -import lombok.AllArgsConstructor; -import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.datavec.api.io.WritableComparator; -import org.datavec.api.transform.Transform; -import org.datavec.api.transform.analysis.columns.ColumnAnalysis; -import org.datavec.api.transform.condition.column.ColumnCondition; -import org.datavec.api.transform.filter.Filter; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.rank.CalculateSortedRank; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.transform.sequence.window.WindowFunction; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.primitives.Pair; -import org.nd4j.serde.json.LegacyIActivationDeserializer; -import org.nd4j.serde.json.LegacyILossFunctionDeserializer; +import org.datavec.api.transform.serde.legacy.LegacyJsonFormat; import org.nd4j.shade.jackson.annotation.JsonAutoDetect; -import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.PropertyAccessor; -import org.nd4j.shade.jackson.databind.*; -import 
org.nd4j.shade.jackson.databind.cfg.MapperConfig; -import org.nd4j.shade.jackson.databind.introspect.Annotated; -import org.nd4j.shade.jackson.databind.introspect.AnnotatedClass; -import org.nd4j.shade.jackson.databind.introspect.AnnotationMap; -import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector; -import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder; +import org.nd4j.shade.jackson.databind.DeserializationFeature; +import org.nd4j.shade.jackson.databind.MapperFeature; +import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import org.nd4j.shade.jackson.datatype.joda.JodaModule; -import java.lang.annotation.Annotation; -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; - /** * JSON mappers for deserializing neural net configurations, etc. * @@ -62,38 +35,9 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class JsonMappers { - /** - * This system property is provided as an alternative to {@link #registerLegacyCustomClassesForJSON(Class[])} - * Classes can be specified in comma-separated format - */ - public static String CUSTOM_REGISTRATION_PROPERTY = "org.datavec.config.custom.legacyclasses"; - - static { - String p = System.getProperty(CUSTOM_REGISTRATION_PROPERTY); - if(p != null && !p.isEmpty()){ - String[] split = p.split(","); - List> list = new ArrayList<>(); - for(String s : split){ - try{ - Class c = Class.forName(s); - list.add(c); - } catch (Throwable t){ - log.warn("Error parsing {} system property: class \"{}\" could not be loaded",CUSTOM_REGISTRATION_PROPERTY, s, t); - } - } - - if(list.size() > 0){ - try { - registerLegacyCustomClassesForJSONList(list); - } catch (Throwable t){ - log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",CUSTOM_REGISTRATION_PROPERTY, t); - } - } - } - } - private static ObjectMapper jsonMapper; private static ObjectMapper yamlMapper; + private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc static { jsonMapper = new ObjectMapper(); @@ -102,117 +46,12 @@ public class JsonMappers { configureMapper(yamlMapper); } - private static Map legacyMappers = new ConcurrentHashMap<>(); - - - /** - * Register a set of classes (Transform, Filter, etc) for JSON deserialization.
- *
- * This is required ONLY when BOTH of the following conditions are met:
- * 1. You want to load a serialized TransformProcess, saved in 1.0.0-alpha or before, AND
- * 2. The serialized TransformProcess has a custom Transform, Filter, etc (i.e., one not defined in DL4J)
- *
- * By passing the classes of these custom classes here, DataVec should be able to deserialize them, in spite of the JSON - * format change between versions. - * - * @param classes Classes to register - */ - public static void registerLegacyCustomClassesForJSON(Class... classes) { - registerLegacyCustomClassesForJSONList(Arrays.>asList(classes)); - } - - /** - * @see #registerLegacyCustomClassesForJSON(Class[]) - */ - public static void registerLegacyCustomClassesForJSONList(List> classes){ - //Default names (i.e., old format for custom JSON format) - List> list = new ArrayList<>(); - for(Class c : classes){ - list.add(new Pair(c.getSimpleName(), c)); + public static synchronized ObjectMapper getLegacyMapper(){ + if(legacyMapper == null){ + legacyMapper = LegacyJsonFormat.legacyMapper(); + configureMapper(legacyMapper); } - registerLegacyCustomClassesForJSON(list); - } - - /** - * Set of classes that can be registered for legacy deserialization. - */ - private static List> REGISTERABLE_CUSTOM_CLASSES = (List>) Arrays.>asList( - Transform.class, - ColumnAnalysis.class, - ColumnCondition.class, - Filter.class, - ColumnMetaData.class, - CalculateSortedRank.class, - Schema.class, - SequenceComparator.class, - SequenceSplit.class, - WindowFunction.class, - Writable.class, - WritableComparator.class - ); - - /** - * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution - * ONLY) for JSON deserialization, with custom names.
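With the registration machinery above removed, legacy-format JSON is now reached through the new getLegacyMapper() added here. A hypothetical round trip using it is sketched below, assuming the legacy mapper both writes and reads the old wrapper-object format; the column name "col0" and the surrounding main method are invented for the example.

    import org.datavec.api.transform.schema.Schema;
    import org.datavec.api.transform.serde.JsonMappers;

    public class LegacyMapperRoundTrip {
        public static void main(String[] args) throws Exception {
            Schema schema = new Schema.Builder()
                    .addColumnInteger("col0")
                    .build();

            // Write in the old (1.0.0-alpha and earlier) format, then read it back through the
            // same legacy mapper - this is the mapper Schema.fromJson falls back to when it
            // encounters old-format JSON without an "@class" field.
            String oldFormatJson = JsonMappers.getLegacyMapper().writeValueAsString(schema);
            Schema restored = JsonMappers.getLegacyMapper().readValue(oldFormatJson, Schema.class);
            System.out.println(restored.getColumnNames());
        }
    }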
- * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])} - * but is added in case it is required in non-standard circumstances. - */ - public static void registerLegacyCustomClassesForJSON(List> classes){ - for(Pair p : classes){ - String s = p.getFirst(); - Class c = p.getRight(); - //Check if it's a valid class to register... - boolean found = false; - for( Class c2 : REGISTERABLE_CUSTOM_CLASSES){ - if(c2.isAssignableFrom(c)){ - Map map = LegacyMappingHelper.legacyMappingForClass(c2); - map.put(p.getFirst(), p.getSecond().getName()); - found = true; - } - } - - if(!found){ - throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " + - c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES); - } - } - } - - - /** - * Get the legacy JSON mapper for the specified class.
- * - * NOTE: This is intended for internal backward-compatibility use. - * - * Note to developers: The following JSON mappers are for handling legacy format JSON. - * Note that after 1.0.0-alpha, the JSON subtype format for Transforms, Filters, Conditions etc were changed from - * a wrapper object, to an "@class" field. However, to not break all saved transforms networks, these mappers are - * part of the solution.
- *
- * How legacy loading works (same pattern for all types - Transform, Filter, Condition etc)
- * 1. Transforms etc JSON that has a "@class" field are deserialized as normal
- * 2. Transforms JSON that don't have such a field are mapped (via Layer @JsonTypeInfo) to LegacyMappingHelper.TransformHelper
- * 3. LegacyMappingHelper.TransformHelper has a @JsonDeserialize annotation - we use LegacyMappingHelper.LegacyTransformDeserializer to handle it
- * 4. LegacyTransformDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names - * 5. BaseLegacyDeserializer (that LegacyTransformDeserializer extends) does a lookup and handles the deserialization - * - * Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format, - * as it'll fail due to not having the expected "@class" annotation. - * Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified - * class anyway. The ignoring is done via an annotation introspector, defined below in this class. - * However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of - * all types - if we did, then any nested types would fail (i.e., an Condition in a Transform - the Transform couldn't - * be deserialized correctly, as the annotation would be ignored). - * - */ - public static synchronized ObjectMapper getLegacyMapperFor(@NonNull Class clazz){ - if(!legacyMappers.containsKey(clazz)){ - ObjectMapper m = new ObjectMapper(); - configureMapper(m); - m.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(clazz))); - legacyMappers.put(clazz, m); - } - return legacyMappers.get(clazz); + return legacyMapper; } /** @@ -237,61 +76,7 @@ public class JsonMappers { ret.enable(SerializationFeature.INDENT_OUTPUT); ret.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); ret.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); + ret.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); //Need this otherwise JsonProperty annotations on constructors won't be seen } - - /** - * Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc. 
- * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring - * a set of JsonTypeInfo annotations - */ - @AllArgsConstructor - private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector { - - private List classList; - - @Override - protected TypeResolverBuilder _findTypeResolver(MapperConfig config, Annotated ann, JavaType baseType) { - if(ann instanceof AnnotatedClass){ - AnnotatedClass c = (AnnotatedClass)ann; - Class annClass = c.getAnnotated(); - - boolean isAssignable = false; - for(Class c2 : classList){ - if(c2.isAssignableFrom(annClass)){ - isAssignable = true; - break; - } - } - - if( isAssignable ){ - AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations(); - if(annotations == null || annotations.annotations() == null){ - //Probably not necessary - but here for safety - return super._findTypeResolver(config, ann, baseType); - } - - AnnotationMap newMap = null; - for(Annotation a : annotations.annotations()){ - Class annType = a.annotationType(); - if(annType == JsonTypeInfo.class){ - //Ignore the JsonTypeInfo annotation on the Layer class - continue; - } - if(newMap == null){ - newMap = new AnnotationMap(); - } - newMap.add(a); - } - if(newMap == null) - return null; - - //Pass the remaining annotations (if any) to the original introspector - AnnotatedClass ann2 = c.withAnnotations(newMap); - return super._findTypeResolver(config, ann2, baseType); - } - } - return super._findTypeResolver(config, ann, baseType); - } - } } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java deleted file mode 100644 index 5a9b48a7c..000000000 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/GenericLegacyDeserializer.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.transform.serde.legacy; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.Getter; -import org.datavec.api.transform.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Map; - -@AllArgsConstructor -@Data -public class GenericLegacyDeserializer extends BaseLegacyDeserializer { - - @Getter - protected final Class deserializedType; - @Getter - protected final Map legacyNamesMap; - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getLegacyMapperFor(getDeserializedType()); - } -} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java new file mode 100644 index 000000000..8df741a49 --- /dev/null +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java @@ -0,0 +1,267 @@ +package org.datavec.api.transform.serde.legacy; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.datavec.api.transform.Transform; +import org.datavec.api.transform.analysis.columns.*; +import org.datavec.api.transform.condition.BooleanCondition; +import org.datavec.api.transform.condition.Condition; +import org.datavec.api.transform.condition.column.*; +import org.datavec.api.transform.condition.sequence.SequenceLengthCondition; +import org.datavec.api.transform.condition.string.StringRegexColumnCondition; +import org.datavec.api.transform.filter.ConditionFilter; +import org.datavec.api.transform.filter.Filter; +import org.datavec.api.transform.filter.FilterInvalidValues; +import org.datavec.api.transform.filter.InvalidNumColumns; +import org.datavec.api.transform.metadata.*; +import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform; +import org.datavec.api.transform.ndarray.NDArrayDistanceTransform; +import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform; +import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform; +import org.datavec.api.transform.rank.CalculateSortedRank; +import org.datavec.api.transform.schema.Schema; +import org.datavec.api.transform.schema.SequenceSchema; +import org.datavec.api.transform.sequence.ReduceSequenceTransform; +import org.datavec.api.transform.sequence.SequenceComparator; +import org.datavec.api.transform.sequence.SequenceSplit; +import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; +import org.datavec.api.transform.sequence.comparator.StringComparator; +import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation; +import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence; +import org.datavec.api.transform.sequence.trim.SequenceTrimTransform; +import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction; +import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; +import org.datavec.api.transform.sequence.window.TimeWindowFunction; +import org.datavec.api.transform.sequence.window.WindowFunction; +import org.datavec.api.transform.stringreduce.IStringReducer; +import org.datavec.api.transform.stringreduce.StringReducer; +import org.datavec.api.transform.transform.categorical.*; +import org.datavec.api.transform.transform.column.*; +import 
org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform; +import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform; +import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault; +import org.datavec.api.transform.transform.doubletransform.*; +import org.datavec.api.transform.transform.integer.*; +import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform; +import org.datavec.api.transform.transform.longtransform.LongMathOpTransform; +import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; +import org.datavec.api.transform.transform.parse.ParseDoubleTransform; +import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform; +import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform; +import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform; +import org.datavec.api.transform.transform.string.*; +import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform; +import org.datavec.api.transform.transform.time.StringToTimeTransform; +import org.datavec.api.transform.transform.time.TimeMathOpTransform; +import org.datavec.api.writable.*; +import org.datavec.api.writable.comparator.*; +import org.nd4j.shade.jackson.annotation.JsonInclude; +import org.nd4j.shade.jackson.annotation.JsonSubTypes; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +/** + * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override + * the existing annotations. + * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field". 
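As a side note on the mixin mechanism this new class relies on: addMixIn overlays the annotations of a proxy class onto an existing type without modifying it, which is how the old wrapper-object subtype format is reinstated for the legacy mapper only. The toy Shape and Circle types below are purely illustrative (they are not DataVec classes) and just show the mechanism in isolation.

    import org.nd4j.shade.jackson.annotation.JsonSubTypes;
    import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
    import org.nd4j.shade.jackson.databind.ObjectMapper;

    public class MixinSketch {
        // Toy hierarchy standing in for Transform/Condition/Filter etc.
        public interface Shape { }
        public static class Circle implements Shape {
            public double radius;
        }

        // Mixin carrying the old-style annotations: the subtype is written as a wrapper object,
        // e.g. {"Circle":{"radius":1.0}}, rather than the current {"@class":"...","radius":1.0}.
        @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT)
        @JsonSubTypes({@JsonSubTypes.Type(value = Circle.class, name = "Circle")})
        public static class ShapeMixin { }

        public static void main(String[] args) throws Exception {
            ObjectMapper legacy = new ObjectMapper();
            // Overlay ShapeMixin's annotations onto Shape - the same call LegacyJsonFormat.legacyMapper()
            // makes for Transform, Filter, Condition, Writable and the other DataVec interfaces.
            legacy.addMixIn(Shape.class, ShapeMixin.class);

            Shape s = legacy.readValue("{\"Circle\":{\"radius\":1.0}}", Shape.class);
            System.out.println(((Circle) s).radius);
        }
    }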
+ * We use these mixins to allow us to still load the old format + * + * @author Alex Black + */ +public class LegacyJsonFormat { + + private LegacyJsonFormat(){ } + + /** + * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before + * @return Object mapper + */ + public static ObjectMapper legacyMapper(){ + ObjectMapper om = new ObjectMapper(); + om.addMixIn(Schema.class, SchemaMixin.class); + om.addMixIn(ColumnMetaData.class, ColumnMetaDataMixin.class); + om.addMixIn(Transform.class, TransformMixin.class); + om.addMixIn(Condition.class, ConditionMixin.class); + om.addMixIn(Writable.class, WritableMixin.class); + om.addMixIn(Filter.class, FilterMixin.class); + om.addMixIn(SequenceComparator.class, SequenceComparatorMixin.class); + om.addMixIn(SequenceSplit.class, SequenceSplitMixin.class); + om.addMixIn(WindowFunction.class, WindowFunctionMixin.class); + om.addMixIn(CalculateSortedRank.class, CalculateSortedRankMixin.class); + om.addMixIn(WritableComparator.class, WritableComparatorMixin.class); + om.addMixIn(ColumnAnalysis.class, ColumnAnalysisMixin.class); + om.addMixIn(IStringReducer.class, IStringReducerMixin.class); + return om; + } + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes({@JsonSubTypes.Type(value = Schema.class, name = "Schema"), + @JsonSubTypes.Type(value = SequenceSchema.class, name = "SequenceSchema")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SchemaMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes({@JsonSubTypes.Type(value = BinaryMetaData.class, name = "Binary"), + @JsonSubTypes.Type(value = BooleanMetaData.class, name = "Boloean"), + @JsonSubTypes.Type(value = CategoricalMetaData.class, name = "Categorical"), + @JsonSubTypes.Type(value = DoubleMetaData.class, name = "Double"), + @JsonSubTypes.Type(value = FloatMetaData.class, name = "Float"), + @JsonSubTypes.Type(value = IntegerMetaData.class, name = "Integer"), + @JsonSubTypes.Type(value = LongMetaData.class, name = "Long"), + @JsonSubTypes.Type(value = NDArrayMetaData.class, name = "NDArray"), + @JsonSubTypes.Type(value = StringMetaData.class, name = "String"), + @JsonSubTypes.Type(value = TimeMetaData.class, name = "Time")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ColumnMetaDataMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank"), + @JsonSubTypes.Type(value = CategoricalToIntegerTransform.class, name = "CategoricalToIntegerTransform"), + @JsonSubTypes.Type(value = CategoricalToOneHotTransform.class, name = "CategoricalToOneHotTransform"), + @JsonSubTypes.Type(value = IntegerToCategoricalTransform.class, name = "IntegerToCategoricalTransform"), + @JsonSubTypes.Type(value = StringToCategoricalTransform.class, name = "StringToCategoricalTransform"), + @JsonSubTypes.Type(value = DuplicateColumnsTransform.class, name = "DuplicateColumnsTransform"), + @JsonSubTypes.Type(value = RemoveColumnsTransform.class, name = "RemoveColumnsTransform"), + @JsonSubTypes.Type(value = RenameColumnsTransform.class, name = "RenameColumnsTransform"), + @JsonSubTypes.Type(value = ReorderColumnsTransform.class, name = "ReorderColumnsTransform"), + @JsonSubTypes.Type(value = ConditionalCopyValueTransform.class, name = "ConditionalCopyValueTransform"), + 
@JsonSubTypes.Type(value = ConditionalReplaceValueTransform.class, name = "ConditionalReplaceValueTransform"), + @JsonSubTypes.Type(value = ConditionalReplaceValueTransformWithDefault.class, name = "ConditionalReplaceValueTransformWithDefault"), + @JsonSubTypes.Type(value = DoubleColumnsMathOpTransform.class, name = "DoubleColumnsMathOpTransform"), + @JsonSubTypes.Type(value = DoubleMathOpTransform.class, name = "DoubleMathOpTransform"), + @JsonSubTypes.Type(value = Log2Normalizer.class, name = "Log2Normalizer"), + @JsonSubTypes.Type(value = MinMaxNormalizer.class, name = "MinMaxNormalizer"), + @JsonSubTypes.Type(value = StandardizeNormalizer.class, name = "StandardizeNormalizer"), + @JsonSubTypes.Type(value = SubtractMeanNormalizer.class, name = "SubtractMeanNormalizer"), + @JsonSubTypes.Type(value = IntegerColumnsMathOpTransform.class, name = "IntegerColumnsMathOpTransform"), + @JsonSubTypes.Type(value = IntegerMathOpTransform.class, name = "IntegerMathOpTransform"), + @JsonSubTypes.Type(value = ReplaceEmptyIntegerWithValueTransform.class, name = "ReplaceEmptyIntegerWithValueTransform"), + @JsonSubTypes.Type(value = ReplaceInvalidWithIntegerTransform.class, name = "ReplaceInvalidWithIntegerTransform"), + @JsonSubTypes.Type(value = LongColumnsMathOpTransform.class, name = "LongColumnsMathOpTransform"), + @JsonSubTypes.Type(value = LongMathOpTransform.class, name = "LongMathOpTransform"), + @JsonSubTypes.Type(value = MapAllStringsExceptListTransform.class, name = "MapAllStringsExceptListTransform"), + @JsonSubTypes.Type(value = RemoveWhiteSpaceTransform.class, name = "RemoveWhiteSpaceTransform"), + @JsonSubTypes.Type(value = ReplaceEmptyStringTransform.class, name = "ReplaceEmptyStringTransform"), + @JsonSubTypes.Type(value = ReplaceStringTransform.class, name = "ReplaceStringTransform"), + @JsonSubTypes.Type(value = StringListToCategoricalSetTransform.class, name = "StringListToCategoricalSetTransform"), + @JsonSubTypes.Type(value = StringMapTransform.class, name = "StringMapTransform"), + @JsonSubTypes.Type(value = DeriveColumnsFromTimeTransform.class, name = "DeriveColumnsFromTimeTransform"), + @JsonSubTypes.Type(value = StringToTimeTransform.class, name = "StringToTimeTransform"), + @JsonSubTypes.Type(value = TimeMathOpTransform.class, name = "TimeMathOpTransform"), + @JsonSubTypes.Type(value = ReduceSequenceByWindowTransform.class, name = "ReduceSequenceByWindowTransform"), + @JsonSubTypes.Type(value = DoubleMathFunctionTransform.class, name = "DoubleMathFunctionTransform"), + @JsonSubTypes.Type(value = AddConstantColumnTransform.class, name = "AddConstantColumnTransform"), + @JsonSubTypes.Type(value = RemoveAllColumnsExceptForTransform.class, name = "RemoveAllColumnsExceptForTransform"), + @JsonSubTypes.Type(value = ParseDoubleTransform.class, name = "ParseDoubleTransform"), + @JsonSubTypes.Type(value = ConvertToString.class, name = "ConvertToStringTransform"), + @JsonSubTypes.Type(value = AppendStringColumnTransform.class, name = "AppendStringColumnTransform"), + @JsonSubTypes.Type(value = SequenceDifferenceTransform.class, name = "SequenceDifferenceTransform"), + @JsonSubTypes.Type(value = ReduceSequenceTransform.class, name = "ReduceSequenceTransform"), + @JsonSubTypes.Type(value = SequenceMovingWindowReduceTransform.class, name = "SequenceMovingWindowReduceTransform"), + @JsonSubTypes.Type(value = IntegerToOneHotTransform.class, name = "IntegerToOneHotTransform"), + @JsonSubTypes.Type(value = SequenceTrimTransform.class, name = "SequenceTrimTransform"), + @JsonSubTypes.Type(value 
= SequenceOffsetTransform.class, name = "SequenceOffsetTransform"), + @JsonSubTypes.Type(value = NDArrayColumnsMathOpTransform.class, name = "NDArrayColumnsMathOpTransform"), + @JsonSubTypes.Type(value = NDArrayDistanceTransform.class, name = "NDArrayDistanceTransform"), + @JsonSubTypes.Type(value = NDArrayMathFunctionTransform.class, name = "NDArrayMathFunctionTransform"), + @JsonSubTypes.Type(value = NDArrayScalarOpTransform.class, name = "NDArrayScalarOpTransform"), + @JsonSubTypes.Type(value = ChangeCaseStringTransform.class, name = "ChangeCaseStringTransform"), + @JsonSubTypes.Type(value = ConcatenateStringColumns.class, name = "ConcatenateStringColumns"), + @JsonSubTypes.Type(value = StringListToCountsNDArrayTransform.class, name = "StringListToCountsNDArrayTransform"), + @JsonSubTypes.Type(value = StringListToIndicesNDArrayTransform.class, name = "StringListToIndicesNDArrayTransform"), + @JsonSubTypes.Type(value = PivotTransform.class, name = "PivotTransform"), + @JsonSubTypes.Type(value = TextToCharacterIndexTransform.class, name = "TextToCharacterIndexTransform")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class TransformMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = TrivialColumnCondition.class, name = "TrivialColumnCondition"), + @JsonSubTypes.Type(value = CategoricalColumnCondition.class, name = "CategoricalColumnCondition"), + @JsonSubTypes.Type(value = DoubleColumnCondition.class, name = "DoubleColumnCondition"), + @JsonSubTypes.Type(value = IntegerColumnCondition.class, name = "IntegerColumnCondition"), + @JsonSubTypes.Type(value = LongColumnCondition.class, name = "LongColumnCondition"), + @JsonSubTypes.Type(value = NullWritableColumnCondition.class, name = "NullWritableColumnCondition"), + @JsonSubTypes.Type(value = StringColumnCondition.class, name = "StringColumnCondition"), + @JsonSubTypes.Type(value = TimeColumnCondition.class, name = "TimeColumnCondition"), + @JsonSubTypes.Type(value = StringRegexColumnCondition.class, name = "StringRegexColumnCondition"), + @JsonSubTypes.Type(value = BooleanCondition.class, name = "BooleanCondition"), + @JsonSubTypes.Type(value = NaNColumnCondition.class, name = "NaNColumnCondition"), + @JsonSubTypes.Type(value = InfiniteColumnCondition.class, name = "InfiniteColumnCondition"), + @JsonSubTypes.Type(value = SequenceLengthCondition.class, name = "SequenceLengthCondition")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ConditionMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ArrayWritable.class, name = "ArrayWritable"), + @JsonSubTypes.Type(value = BooleanWritable.class, name = "BooleanWritable"), + @JsonSubTypes.Type(value = ByteWritable.class, name = "ByteWritable"), + @JsonSubTypes.Type(value = DoubleWritable.class, name = "DoubleWritable"), + @JsonSubTypes.Type(value = FloatWritable.class, name = "FloatWritable"), + @JsonSubTypes.Type(value = IntWritable.class, name = "IntWritable"), + @JsonSubTypes.Type(value = LongWritable.class, name = "LongWritable"), + @JsonSubTypes.Type(value = NullWritable.class, name = "NullWritable"), + @JsonSubTypes.Type(value = Text.class, name = "Text"), + @JsonSubTypes.Type(value = BytesWritable.class, name = "BytesWritable")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WritableMixin { } + + 
@JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ConditionFilter.class, name = "ConditionFilter"), + @JsonSubTypes.Type(value = FilterInvalidValues.class, name = "FilterInvalidValues"), + @JsonSubTypes.Type(value = InvalidNumColumns.class, name = "InvalidNumCols")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class FilterMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = NumericalColumnComparator.class, name = "NumericalColumnComparator"), + @JsonSubTypes.Type(value = StringComparator.class, name = "StringComparator")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SequenceComparatorMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = SequenceSplitTimeSeparation.class, name = "SequenceSplitTimeSeparation"), + @JsonSubTypes.Type(value = SplitMaxLengthSequence.class, name = "SplitMaxLengthSequence")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class SequenceSplitMixin { } + + @JsonInclude(JsonInclude.Include.NON_NULL) + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = TimeWindowFunction.class, name = "TimeWindowFunction"), + @JsonSubTypes.Type(value = OverlappingTimeWindowFunction.class, name = "OverlappingTimeWindowFunction")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WindowFunctionMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CalculateSortedRank.class, name = "CalculateSortedRank")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class CalculateSortedRankMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = DoubleWritableComparator.class, name = "DoubleWritableComparator"), + @JsonSubTypes.Type(value = FloatWritableComparator.class, name = "FloatWritableComparator"), + @JsonSubTypes.Type(value = IntWritableComparator.class, name = "IntWritableComparator"), + @JsonSubTypes.Type(value = LongWritableComparator.class, name = "LongWritableComparator"), + @JsonSubTypes.Type(value = TextWritableComparator.class, name = "TextWritableComparator")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class WritableComparatorMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = BytesAnalysis.class, name = "BytesAnalysis"), + @JsonSubTypes.Type(value = CategoricalAnalysis.class, name = "CategoricalAnalysis"), + @JsonSubTypes.Type(value = DoubleAnalysis.class, name = "DoubleAnalysis"), + @JsonSubTypes.Type(value = IntegerAnalysis.class, name = "IntegerAnalysis"), + @JsonSubTypes.Type(value = LongAnalysis.class, name = "LongAnalysis"), + @JsonSubTypes.Type(value = StringAnalysis.class, name = "StringAnalysis"), + @JsonSubTypes.Type(value = TimeAnalysis.class, name = "TimeAnalysis")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ColumnAnalysisMixin{ } + 
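For context, a minimal sketch (illustrative only, not part of this patch) of what the mixins above buy us: JSON written by 1.0.0-alpha and earlier used Jackson's wrapper-object layout, keyed by the subtype name, while the current format carries the concrete class in an "@class" property; legacyMapper() is the entry point for reading the old layout. The DoubleWritable field name and the LegacyJsonFormat package below are assumptions made for illustration.

    import org.datavec.api.transform.serde.legacy.LegacyJsonFormat; // package assumed to match the deleted LegacyMappingHelper
    import org.datavec.api.writable.Writable;
    import org.nd4j.shade.jackson.databind.ObjectMapper;

    public class LegacyJsonExample {
        public static void main(String[] args) throws Exception {
            // Old (1.0.0-alpha and before) wrapper-object layout: the subtype name wraps the payload
            String oldJson = "{\"DoubleWritable\" : {\"value\" : 1.0}}";
            // Current layout stores the concrete type in an "@class" property instead:
            //   {"@class" : "org.datavec.api.writable.DoubleWritable", "value" : 1.0}

            ObjectMapper legacy = LegacyJsonFormat.legacyMapper();   // mixin-configured mapper, "minus general config"
            Writable w = legacy.readValue(oldJson, Writable.class);  // resolved via the WritableMixin subtype names
            System.out.println(w);
        }
    }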
+ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = StringReducer.class, name = "StringReducer")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class IStringReducerMixin{ } +} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java deleted file mode 100644 index c4b478278..000000000 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyMappingHelper.java +++ /dev/null @@ -1,535 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.api.transform.serde.legacy; - -import org.datavec.api.transform.Transform; -import org.datavec.api.transform.analysis.columns.*; -import org.datavec.api.transform.condition.BooleanCondition; -import org.datavec.api.transform.condition.Condition; -import org.datavec.api.transform.condition.column.*; -import org.datavec.api.transform.condition.sequence.SequenceLengthCondition; -import org.datavec.api.transform.condition.string.StringRegexColumnCondition; -import org.datavec.api.transform.filter.ConditionFilter; -import org.datavec.api.transform.filter.Filter; -import org.datavec.api.transform.filter.FilterInvalidValues; -import org.datavec.api.transform.filter.InvalidNumColumns; -import org.datavec.api.transform.metadata.*; -import org.datavec.api.transform.ndarray.NDArrayColumnsMathOpTransform; -import org.datavec.api.transform.ndarray.NDArrayDistanceTransform; -import org.datavec.api.transform.ndarray.NDArrayMathFunctionTransform; -import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform; -import org.datavec.api.transform.rank.CalculateSortedRank; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.transform.sequence.ReduceSequenceTransform; -import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; -import org.datavec.api.transform.sequence.comparator.StringComparator; -import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation; -import org.datavec.api.transform.sequence.split.SplitMaxLengthSequence; -import org.datavec.api.transform.sequence.trim.SequenceTrimTransform; -import org.datavec.api.transform.sequence.window.OverlappingTimeWindowFunction; -import org.datavec.api.transform.sequence.window.ReduceSequenceByWindowTransform; -import org.datavec.api.transform.sequence.window.TimeWindowFunction; -import org.datavec.api.transform.sequence.window.WindowFunction; -import 
org.datavec.api.transform.stringreduce.IStringReducer; -import org.datavec.api.transform.stringreduce.StringReducer; -import org.datavec.api.transform.transform.categorical.*; -import org.datavec.api.transform.transform.column.*; -import org.datavec.api.transform.transform.condition.ConditionalCopyValueTransform; -import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransform; -import org.datavec.api.transform.transform.condition.ConditionalReplaceValueTransformWithDefault; -import org.datavec.api.transform.transform.doubletransform.*; -import org.datavec.api.transform.transform.integer.*; -import org.datavec.api.transform.transform.longtransform.LongColumnsMathOpTransform; -import org.datavec.api.transform.transform.longtransform.LongMathOpTransform; -import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; -import org.datavec.api.transform.transform.nlp.TextToTermIndexSequenceTransform; -import org.datavec.api.transform.transform.parse.ParseDoubleTransform; -import org.datavec.api.transform.transform.sequence.SequenceDifferenceTransform; -import org.datavec.api.transform.transform.sequence.SequenceMovingWindowReduceTransform; -import org.datavec.api.transform.transform.sequence.SequenceOffsetTransform; -import org.datavec.api.transform.transform.string.*; -import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform; -import org.datavec.api.transform.transform.time.StringToTimeTransform; -import org.datavec.api.transform.transform.time.TimeMathOpTransform; -import org.datavec.api.writable.*; -import org.datavec.api.writable.comparator.*; -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -import java.util.HashMap; -import java.util.Map; - -public class LegacyMappingHelper { - - public static Map legacyMappingForClass(Class c){ - //Need to be able to get the map - and they need to be mutable... 
- switch (c.getSimpleName()){ - case "Transform": - return getLegacyMappingImageTransform(); - case "ColumnAnalysis": - return getLegacyMappingColumnAnalysis(); - case "Condition": - return getLegacyMappingCondition(); - case "Filter": - return getLegacyMappingFilter(); - case "ColumnMetaData": - return mapColumnMetaData; - case "CalculateSortedRank": - return mapCalculateSortedRank; - case "Schema": - return mapSchema; - case "SequenceComparator": - return mapSequenceComparator; - case "SequenceSplit": - return mapSequenceSplit; - case "WindowFunction": - return mapWindowFunction; - case "IStringReducer": - return mapIStringReducer; - case "Writable": - return mapWritable; - case "WritableComparator": - return mapWritableComparator; - case "ImageTransform": - return mapImageTransform; - default: - //Should never happen - throw new IllegalArgumentException("No legacy mapping available for class " + c.getName()); - } - } - - private static Map mapTransform; - private static Map mapColumnAnalysis; - private static Map mapCondition; - private static Map mapFilter; - private static Map mapColumnMetaData; - private static Map mapCalculateSortedRank; - private static Map mapSchema; - private static Map mapSequenceComparator; - private static Map mapSequenceSplit; - private static Map mapWindowFunction; - private static Map mapIStringReducer; - private static Map mapWritable; - private static Map mapWritableComparator; - private static Map mapImageTransform; - - private static synchronized Map getLegacyMappingTransform(){ - - if(mapTransform == null) { - //The following classes all used their class short name - Map m = new HashMap<>(); - m.put("CategoricalToIntegerTransform", CategoricalToIntegerTransform.class.getName()); - m.put("CategoricalToOneHotTransform", CategoricalToOneHotTransform.class.getName()); - m.put("IntegerToCategoricalTransform", IntegerToCategoricalTransform.class.getName()); - m.put("StringToCategoricalTransform", StringToCategoricalTransform.class.getName()); - m.put("DuplicateColumnsTransform", DuplicateColumnsTransform.class.getName()); - m.put("RemoveColumnsTransform", RemoveColumnsTransform.class.getName()); - m.put("RenameColumnsTransform", RenameColumnsTransform.class.getName()); - m.put("ReorderColumnsTransform", ReorderColumnsTransform.class.getName()); - m.put("ConditionalCopyValueTransform", ConditionalCopyValueTransform.class.getName()); - m.put("ConditionalReplaceValueTransform", ConditionalReplaceValueTransform.class.getName()); - m.put("ConditionalReplaceValueTransformWithDefault", ConditionalReplaceValueTransformWithDefault.class.getName()); - m.put("DoubleColumnsMathOpTransform", DoubleColumnsMathOpTransform.class.getName()); - m.put("DoubleMathOpTransform", DoubleMathOpTransform.class.getName()); - m.put("Log2Normalizer", Log2Normalizer.class.getName()); - m.put("MinMaxNormalizer", MinMaxNormalizer.class.getName()); - m.put("StandardizeNormalizer", StandardizeNormalizer.class.getName()); - m.put("SubtractMeanNormalizer", SubtractMeanNormalizer.class.getName()); - m.put("IntegerColumnsMathOpTransform", IntegerColumnsMathOpTransform.class.getName()); - m.put("IntegerMathOpTransform", IntegerMathOpTransform.class.getName()); - m.put("ReplaceEmptyIntegerWithValueTransform", ReplaceEmptyIntegerWithValueTransform.class.getName()); - m.put("ReplaceInvalidWithIntegerTransform", ReplaceInvalidWithIntegerTransform.class.getName()); - m.put("LongColumnsMathOpTransform", LongColumnsMathOpTransform.class.getName()); - m.put("LongMathOpTransform", 
LongMathOpTransform.class.getName()); - m.put("MapAllStringsExceptListTransform", MapAllStringsExceptListTransform.class.getName()); - m.put("RemoveWhiteSpaceTransform", RemoveWhiteSpaceTransform.class.getName()); - m.put("ReplaceEmptyStringTransform", ReplaceEmptyStringTransform.class.getName()); - m.put("ReplaceStringTransform", ReplaceStringTransform.class.getName()); - m.put("StringListToCategoricalSetTransform", StringListToCategoricalSetTransform.class.getName()); - m.put("StringMapTransform", StringMapTransform.class.getName()); - m.put("DeriveColumnsFromTimeTransform", DeriveColumnsFromTimeTransform.class.getName()); - m.put("StringToTimeTransform", StringToTimeTransform.class.getName()); - m.put("TimeMathOpTransform", TimeMathOpTransform.class.getName()); - m.put("ReduceSequenceByWindowTransform", ReduceSequenceByWindowTransform.class.getName()); - m.put("DoubleMathFunctionTransform", DoubleMathFunctionTransform.class.getName()); - m.put("AddConstantColumnTransform", AddConstantColumnTransform.class.getName()); - m.put("RemoveAllColumnsExceptForTransform", RemoveAllColumnsExceptForTransform.class.getName()); - m.put("ParseDoubleTransform", ParseDoubleTransform.class.getName()); - m.put("ConvertToStringTransform", ConvertToString.class.getName()); - m.put("AppendStringColumnTransform", AppendStringColumnTransform.class.getName()); - m.put("SequenceDifferenceTransform", SequenceDifferenceTransform.class.getName()); - m.put("ReduceSequenceTransform", ReduceSequenceTransform.class.getName()); - m.put("SequenceMovingWindowReduceTransform", SequenceMovingWindowReduceTransform.class.getName()); - m.put("IntegerToOneHotTransform", IntegerToOneHotTransform.class.getName()); - m.put("SequenceTrimTransform", SequenceTrimTransform.class.getName()); - m.put("SequenceOffsetTransform", SequenceOffsetTransform.class.getName()); - m.put("NDArrayColumnsMathOpTransform", NDArrayColumnsMathOpTransform.class.getName()); - m.put("NDArrayDistanceTransform", NDArrayDistanceTransform.class.getName()); - m.put("NDArrayMathFunctionTransform", NDArrayMathFunctionTransform.class.getName()); - m.put("NDArrayScalarOpTransform", NDArrayScalarOpTransform.class.getName()); - m.put("ChangeCaseStringTransform", ChangeCaseStringTransform.class.getName()); - m.put("ConcatenateStringColumns", ConcatenateStringColumns.class.getName()); - m.put("StringListToCountsNDArrayTransform", StringListToCountsNDArrayTransform.class.getName()); - m.put("StringListToIndicesNDArrayTransform", StringListToIndicesNDArrayTransform.class.getName()); - m.put("PivotTransform", PivotTransform.class.getName()); - m.put("TextToCharacterIndexTransform", TextToCharacterIndexTransform.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(TextToTermIndexSequenceTransform.class.getSimpleName(), TextToTermIndexSequenceTransform.class.getName()); - m.put(ConvertToInteger.class.getSimpleName(), ConvertToInteger.class.getName()); - m.put(ConvertToDouble.class.getSimpleName(), ConvertToDouble.class.getName()); - - mapTransform = m; - } - - return mapTransform; - } - - private static Map getLegacyMappingColumnAnalysis(){ - if(mapColumnAnalysis == null) { - Map m = new HashMap<>(); - m.put("BytesAnalysis", BytesAnalysis.class.getName()); - m.put("CategoricalAnalysis", CategoricalAnalysis.class.getName()); - m.put("DoubleAnalysis", DoubleAnalysis.class.getName()); - m.put("IntegerAnalysis", IntegerAnalysis.class.getName()); - m.put("LongAnalysis", LongAnalysis.class.getName()); - 
m.put("StringAnalysis", StringAnalysis.class.getName()); - m.put("TimeAnalysis", TimeAnalysis.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(NDArrayAnalysis.class.getSimpleName(), NDArrayAnalysis.class.getName()); - - mapColumnAnalysis = m; - } - - return mapColumnAnalysis; - } - - private static Map getLegacyMappingCondition(){ - if(mapCondition == null) { - Map m = new HashMap<>(); - m.put("TrivialColumnCondition", TrivialColumnCondition.class.getName()); - m.put("CategoricalColumnCondition", CategoricalColumnCondition.class.getName()); - m.put("DoubleColumnCondition", DoubleColumnCondition.class.getName()); - m.put("IntegerColumnCondition", IntegerColumnCondition.class.getName()); - m.put("LongColumnCondition", LongColumnCondition.class.getName()); - m.put("NullWritableColumnCondition", NullWritableColumnCondition.class.getName()); - m.put("StringColumnCondition", StringColumnCondition.class.getName()); - m.put("TimeColumnCondition", TimeColumnCondition.class.getName()); - m.put("StringRegexColumnCondition", StringRegexColumnCondition.class.getName()); - m.put("BooleanCondition", BooleanCondition.class.getName()); - m.put("NaNColumnCondition", NaNColumnCondition.class.getName()); - m.put("InfiniteColumnCondition", InfiniteColumnCondition.class.getName()); - m.put("SequenceLengthCondition", SequenceLengthCondition.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(InvalidValueColumnCondition.class.getSimpleName(), InvalidValueColumnCondition.class.getName()); - m.put(BooleanColumnCondition.class.getSimpleName(), BooleanColumnCondition.class.getName()); - - mapCondition = m; - } - - return mapCondition; - } - - private static Map getLegacyMappingFilter(){ - if(mapFilter == null) { - Map m = new HashMap<>(); - m.put("ConditionFilter", ConditionFilter.class.getName()); - m.put("FilterInvalidValues", FilterInvalidValues.class.getName()); - m.put("InvalidNumCols", InvalidNumColumns.class.getName()); - - mapFilter = m; - } - return mapFilter; - } - - private static Map getLegacyMappingColumnMetaData(){ - if(mapColumnMetaData == null) { - Map m = new HashMap<>(); - m.put("Categorical", CategoricalMetaData.class.getName()); - m.put("Double", DoubleMetaData.class.getName()); - m.put("Float", FloatMetaData.class.getName()); - m.put("Integer", IntegerMetaData.class.getName()); - m.put("Long", LongMetaData.class.getName()); - m.put("String", StringMetaData.class.getName()); - m.put("Time", TimeMetaData.class.getName()); - m.put("NDArray", NDArrayMetaData.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(BooleanMetaData.class.getSimpleName(), BooleanMetaData.class.getName()); - m.put(BinaryMetaData.class.getSimpleName(), BinaryMetaData.class.getName()); - - mapColumnMetaData = m; - } - - return mapColumnMetaData; - } - - private static Map getLegacyMappingCalculateSortedRank(){ - if(mapCalculateSortedRank == null) { - Map m = new HashMap<>(); - m.put("CalculateSortedRank", CalculateSortedRank.class.getName()); - mapCalculateSortedRank = m; - } - return mapCalculateSortedRank; - } - - private static Map getLegacyMappingSchema(){ - if(mapSchema == null) { - Map m = new HashMap<>(); - m.put("Schema", Schema.class.getName()); - m.put("SequenceSchema", SequenceSchema.class.getName()); - - mapSchema = m; - } - return mapSchema; - } - - private static Map getLegacyMappingSequenceComparator(){ - 
if(mapSequenceComparator == null) { - Map m = new HashMap<>(); - m.put("NumericalColumnComparator", NumericalColumnComparator.class.getName()); - m.put("StringComparator", StringComparator.class.getName()); - - mapSequenceComparator = m; - } - return mapSequenceComparator; - } - - private static Map getLegacyMappingSequenceSplit(){ - if(mapSequenceSplit == null) { - Map m = new HashMap<>(); - m.put("SequenceSplitTimeSeparation", SequenceSplitTimeSeparation.class.getName()); - m.put("SplitMaxLengthSequence", SplitMaxLengthSequence.class.getName()); - - mapSequenceSplit = m; - } - return mapSequenceSplit; - } - - private static Map getLegacyMappingWindowFunction(){ - if(mapWindowFunction == null) { - Map m = new HashMap<>(); - m.put("TimeWindowFunction", TimeWindowFunction.class.getName()); - m.put("OverlappingTimeWindowFunction", OverlappingTimeWindowFunction.class.getName()); - - mapWindowFunction = m; - } - return mapWindowFunction; - } - - private static Map getLegacyMappingIStringReducer(){ - if(mapIStringReducer == null) { - Map m = new HashMap<>(); - m.put("StringReducer", StringReducer.class.getName()); - - mapIStringReducer = m; - } - return mapIStringReducer; - } - - private static Map getLegacyMappingWritable(){ - if (mapWritable == null) { - Map m = new HashMap<>(); - m.put("ArrayWritable", ArrayWritable.class.getName()); - m.put("BooleanWritable", BooleanWritable.class.getName()); - m.put("ByteWritable", ByteWritable.class.getName()); - m.put("DoubleWritable", DoubleWritable.class.getName()); - m.put("FloatWritable", FloatWritable.class.getName()); - m.put("IntWritable", IntWritable.class.getName()); - m.put("LongWritable", LongWritable.class.getName()); - m.put("NullWritable", NullWritable.class.getName()); - m.put("Text", Text.class.getName()); - m.put("BytesWritable", BytesWritable.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(NDArrayWritable.class.getSimpleName(), NDArrayWritable.class.getName()); - - mapWritable = m; - } - - return mapWritable; - } - - private static Map getLegacyMappingWritableComparator(){ - if(mapWritableComparator == null) { - Map m = new HashMap<>(); - m.put("DoubleWritableComparator", DoubleWritableComparator.class.getName()); - m.put("FloatWritableComparator", FloatWritableComparator.class.getName()); - m.put("IntWritableComparator", IntWritableComparator.class.getName()); - m.put("LongWritableComparator", LongWritableComparator.class.getName()); - m.put("TextWritableComparator", TextWritableComparator.class.getName()); - - //The following never had subtype annotations, and hence will have had the default name: - m.put(ByteWritable.Comparator.class.getSimpleName(), ByteWritable.Comparator.class.getName()); - m.put(FloatWritable.Comparator.class.getSimpleName(), FloatWritable.Comparator.class.getName()); - m.put(IntWritable.Comparator.class.getSimpleName(), IntWritable.Comparator.class.getName()); - m.put(BooleanWritable.Comparator.class.getSimpleName(), BooleanWritable.Comparator.class.getName()); - m.put(LongWritable.Comparator.class.getSimpleName(), LongWritable.Comparator.class.getName()); - m.put(Text.Comparator.class.getSimpleName(), Text.Comparator.class.getName()); - m.put(LongWritable.DecreasingComparator.class.getSimpleName(), LongWritable.DecreasingComparator.class.getName()); - m.put(DoubleWritable.Comparator.class.getSimpleName(), DoubleWritable.Comparator.class.getName()); - - mapWritableComparator = m; - } - - return mapWritableComparator; - } - - public static Map 
getLegacyMappingImageTransform(){ - if(mapImageTransform == null) { - Map m = new HashMap<>(); - m.put("EqualizeHistTransform", "org.datavec.image.transform.EqualizeHistTransform"); - m.put("RotateImageTransform", "org.datavec.image.transform.RotateImageTransform"); - m.put("ColorConversionTransform", "org.datavec.image.transform.ColorConversionTransform"); - m.put("WarpImageTransform", "org.datavec.image.transform.WarpImageTransform"); - m.put("BoxImageTransform", "org.datavec.image.transform.BoxImageTransform"); - m.put("CropImageTransform", "org.datavec.image.transform.CropImageTransform"); - m.put("FilterImageTransform", "org.datavec.image.transform.FilterImageTransform"); - m.put("FlipImageTransform", "org.datavec.image.transform.FlipImageTransform"); - m.put("LargestBlobCropTransform", "org.datavec.image.transform.LargestBlobCropTransform"); - m.put("ResizeImageTransform", "org.datavec.image.transform.ResizeImageTransform"); - m.put("RandomCropTransform", "org.datavec.image.transform.RandomCropTransform"); - m.put("ScaleImageTransform", "org.datavec.image.transform.ScaleImageTransform"); - - mapImageTransform = m; - } - return mapImageTransform; - } - - @JsonDeserialize(using = LegacyTransformDeserializer.class) - public static class TransformHelper { } - - public static class LegacyTransformDeserializer extends GenericLegacyDeserializer { - public LegacyTransformDeserializer() { - super(Transform.class, getLegacyMappingTransform()); - } - } - - @JsonDeserialize(using = LegacyColumnAnalysisDeserializer.class) - public static class ColumnAnalysisHelper { } - - public static class LegacyColumnAnalysisDeserializer extends GenericLegacyDeserializer { - public LegacyColumnAnalysisDeserializer() { - super(ColumnAnalysis.class, getLegacyMappingColumnAnalysis()); - } - } - - @JsonDeserialize(using = LegacyConditionDeserializer.class) - public static class ConditionHelper { } - - public static class LegacyConditionDeserializer extends GenericLegacyDeserializer { - public LegacyConditionDeserializer() { - super(Condition.class, getLegacyMappingCondition()); - } - } - - @JsonDeserialize(using = LegacyFilterDeserializer.class) - public static class FilterHelper { } - - public static class LegacyFilterDeserializer extends GenericLegacyDeserializer { - public LegacyFilterDeserializer() { - super(Filter.class, getLegacyMappingFilter()); - } - } - - @JsonDeserialize(using = LegacyColumnMetaDataDeserializer.class) - public static class ColumnMetaDataHelper { } - - public static class LegacyColumnMetaDataDeserializer extends GenericLegacyDeserializer { - public LegacyColumnMetaDataDeserializer() { - super(ColumnMetaData.class, getLegacyMappingColumnMetaData()); - } - } - - @JsonDeserialize(using = LegacyCalculateSortedRankDeserializer.class) - public static class CalculateSortedRankHelper { } - - public static class LegacyCalculateSortedRankDeserializer extends GenericLegacyDeserializer { - public LegacyCalculateSortedRankDeserializer() { - super(CalculateSortedRank.class, getLegacyMappingCalculateSortedRank()); - } - } - - @JsonDeserialize(using = LegacySchemaDeserializer.class) - public static class SchemaHelper { } - - public static class LegacySchemaDeserializer extends GenericLegacyDeserializer { - public LegacySchemaDeserializer() { - super(Schema.class, getLegacyMappingSchema()); - } - } - - @JsonDeserialize(using = LegacySequenceComparatorDeserializer.class) - public static class SequenceComparatorHelper { } - - public static class LegacySequenceComparatorDeserializer extends 
GenericLegacyDeserializer { - public LegacySequenceComparatorDeserializer() { - super(SequenceComparator.class, getLegacyMappingSequenceComparator()); - } - } - - @JsonDeserialize(using = LegacySequenceSplitDeserializer.class) - public static class SequenceSplitHelper { } - - public static class LegacySequenceSplitDeserializer extends GenericLegacyDeserializer { - public LegacySequenceSplitDeserializer() { - super(SequenceSplit.class, getLegacyMappingSequenceSplit()); - } - } - - @JsonDeserialize(using = LegacyWindowFunctionDeserializer.class) - public static class WindowFunctionHelper { } - - public static class LegacyWindowFunctionDeserializer extends GenericLegacyDeserializer { - public LegacyWindowFunctionDeserializer() { - super(WindowFunction.class, getLegacyMappingWindowFunction()); - } - } - - - @JsonDeserialize(using = LegacyIStringReducerDeserializer.class) - public static class IStringReducerHelper { } - - public static class LegacyIStringReducerDeserializer extends GenericLegacyDeserializer { - public LegacyIStringReducerDeserializer() { - super(IStringReducer.class, getLegacyMappingIStringReducer()); - } - } - - - @JsonDeserialize(using = LegacyWritableDeserializer.class) - public static class WritableHelper { } - - public static class LegacyWritableDeserializer extends GenericLegacyDeserializer { - public LegacyWritableDeserializer() { - super(Writable.class, getLegacyMappingWritable()); - } - } - - @JsonDeserialize(using = LegacyWritableComparatorDeserializer.class) - public static class WritableComparatorHelper { } - - public static class LegacyWritableComparatorDeserializer extends GenericLegacyDeserializer { - public LegacyWritableComparatorDeserializer() { - super(WritableComparator.class, getLegacyMappingWritableComparator()); - } - } -} diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java index f43e189d8..54bd7b7c8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java @@ -17,7 +17,6 @@ package org.datavec.api.transform.stringreduce; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -31,8 +30,7 @@ import java.util.List; * a single List */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.IStringReducerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface IStringReducer extends Serializable { /** diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index e3ba797c8..c55d4d3bb 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -16,7 +16,7 @@ package org.datavec.api.util.ndarray; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import 
it.unimi.dsi.fastutil.doubles.DoubleArrayList; import lombok.NonNull; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java index 2584d8b9f..ae5e3a567 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java index 72f4b3c5b..39f41c076 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java index 1b54e7d54..f0ab62bef 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java index 3803c8098..1c127e0f8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java index 58cd45829..4a767a183 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java @@ -17,7 +17,7 @@ package org.datavec.api.writable; -import com.google.common.math.DoubleMath; +import org.nd4j.shade.guava.math.DoubleMath; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java 
b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java index 30eb5e25e..5085dd3f2 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Writable.java @@ -16,7 +16,6 @@ package org.datavec.api.writable; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.DataInput; @@ -60,8 +59,7 @@ import java.io.Serializable; * } *

*/ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.WritableHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Writable extends Serializable { /** * Serialize the fields of this object to out. diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java index 4d638124e..0a5ddddb8 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java @@ -16,7 +16,7 @@ package org.datavec.api.writable.batch; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Data; import lombok.NonNull; import org.datavec.api.writable.NDArrayWritable; diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java index 07ef7ee56..b8e540f61 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java @@ -16,16 +16,13 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; import org.datavec.api.writable.Writable; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.Serializable; import java.util.Comparator; -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyMappingHelper.WritableComparatorHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface WritableComparator extends Comparator, Serializable { } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index 647c0af65..e8ce37bd3 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -16,7 +16,7 @@ package org.datavec.api.split; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.labels.ParentPathLabelGenerator; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index e67fc1f61..c9fb57eb9 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -16,7 +16,7 @@ package org.datavec.api.split.parittion; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; diff --git 
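Aside: the com.google.common -> org.nd4j.shade.guava rewrites in the hunks above (and the removal of the explicit guava dependencies from the poms) are mechanical, but the practical effect is that DataVec now uses the Guava copy relocated inside the ND4J artifacts, so it cannot clash with whatever Guava version an application brings along. A minimal sketch of the shaded usage, assuming an ND4J artifact bundling the shaded classes is on the classpath; nothing below is DataVec-specific beyond the package prefix.

    import org.nd4j.shade.guava.base.Preconditions; // relocated Guava, bundled with ND4J
    import org.nd4j.shade.guava.collect.Lists;

    import java.util.List;

    public class ShadedGuavaExample {
        public static void main(String[] args) {
            List<String> columns = Lists.newArrayList("col0", "col1");
            Preconditions.checkArgument(!columns.isEmpty(), "Column list must not be empty");
            System.out.println(columns);
        }
    }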
a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index c98b58381..dff90f8b9 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -78,8 +78,9 @@ public class TestJsonYaml { public void testMissingPrimitives() { Schema schema = new Schema.Builder().addColumnDouble("Dbl2", null, 100.0, false, false).build(); - - String strJson = "{\n" + " \"Schema\" : {\n" + " \"columns\" : [ {\n" + " \"Double\" : {\n" + //Legacy format JSON + String strJson = "{\n" + " \"Schema\" : {\n" + + " \"columns\" : [ {\n" + " \"Double\" : {\n" + " \"name\" : \"Dbl2\",\n" + " \"maxAllowedValue\" : 100.0\n" + //" \"allowNaN\" : false,\n" + //Normally included: but exclude here to test //" \"allowInfinite\" : false\n" + //Normally included: but exclude here to test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index ce7d779dc..6dfacdd93 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -16,7 +16,7 @@ package org.datavec.api.writable; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.junit.Test; diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml index 4bda93a5e..645971a45 100644 --- a/datavec/datavec-arrow/pom.xml +++ b/datavec/datavec-arrow/pom.xml @@ -34,42 +34,6 @@ nd4j-arrow ${project.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-xml - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${spark2.jackson.version} - org.datavec datavec-api @@ -80,11 +44,6 @@ hppc ${hppc.version} - - com.google.guava - guava - ${guava.version} - org.apache.arrow arrow-vector diff --git a/datavec/datavec-data/datavec-data-codec/pom.xml b/datavec/datavec-data/datavec-data-codec/pom.xml index 6a6bceda6..2a65cb41a 100644 --- a/datavec/datavec-data/datavec-data-codec/pom.xml +++ b/datavec/datavec-data/datavec-data-codec/pom.xml @@ -43,12 +43,6 @@ datavec-api ${project.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - test-nd4j-native diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml index 3b8664650..b97c9d5c6 100644 --- a/datavec/datavec-data/datavec-data-image/pom.xml +++ b/datavec/datavec-data/datavec-data-image/pom.xml @@ -31,20 +31,12 @@ datavec-api ${project.version} - ch.qos.logback logback-classic ${logback.version} test - - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.nd4j nd4j-buffer @@ -75,7 +67,6 @@ imageio-bmp 3.1.1 - com.google.android android @@ -88,7 +79,6 @@ true - org.bytedeco javacpp @@ -99,25 +89,21 @@ 
javacv ${javacv.version} - org.bytedeco opencv-platform ${opencv.version}-${javacpp-presets.version} - org.bytedeco leptonica-platform ${leptonica.version}-${javacpp-presets.version} - org.bytedeco hdf5-platform ${hdf5.version}-${javacpp-presets.version} - @@ -143,5 +129,4 @@ test-nd4j-cuda-10.1 - diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index d80f12a2e..a962dd84a 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -16,7 +16,7 @@ package org.datavec.image.recordreader; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java deleted file mode 100644 index 5e7b09c12..000000000 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/serde/LegacyImageMappingHelper.java +++ /dev/null @@ -1,35 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.image.serde; - -import org.datavec.api.transform.serde.legacy.GenericLegacyDeserializer; -import org.datavec.api.transform.serde.legacy.LegacyMappingHelper; -import org.datavec.image.transform.ImageTransform; -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -public class LegacyImageMappingHelper { - - @JsonDeserialize(using = LegacyImageTransformDeserializer.class) - public static class ImageTransformHelper { } - - public static class LegacyImageTransformDeserializer extends GenericLegacyDeserializer { - public LegacyImageTransformDeserializer() { - super(ImageTransform.class, LegacyMappingHelper.getLegacyMappingImageTransform()); - } - } - -} diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java index 39239494b..afcdf894f 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java @@ -16,11 +16,8 @@ package org.datavec.image.transform; -import lombok.Data; import org.datavec.image.data.ImageWritable; -import org.datavec.image.serde.LegacyImageMappingHelper; import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.util.Random; @@ -31,8 +28,7 @@ import java.util.Random; * @author saudet */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyImageMappingHelper.ImageTransformHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ImageTransform { /** diff --git a/datavec/datavec-data/datavec-data-nlp/pom.xml b/datavec/datavec-data/datavec-data-nlp/pom.xml index 6c6589ffe..17ad11211 100644 --- a/datavec/datavec-data/datavec-data-nlp/pom.xml +++ b/datavec/datavec-data/datavec-data-nlp/pom.xml @@ -31,7 +31,6 @@ UTF-8 2.0.0 - @@ -75,13 +74,6 @@ cleartk-opennlp-tools ${cleartk.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - - ch.qos.logback logback-classic diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-hadoop/pom.xml index 787b64052..c95e6d3bc 100644 --- a/datavec/datavec-hadoop/pom.xml +++ b/datavec/datavec-hadoop/pom.xml @@ -50,11 +50,6 @@ netty ${netty.version} - - com.google.guava - guava - ${guava.version} - org.apache.commons commons-compress @@ -95,14 +90,6 @@ - - - - org.nd4j - nd4j-native - ${nd4j.version} - test - diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java index 01bef8fa9..58f7a57db 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import 
org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java index cf1a801f5..1cbe47176 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java index 992c91312..faf41cbb4 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.reader; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.*; diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java index d35becd7a..7cd112c63 100644 --- a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java +++ b/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java @@ -16,7 +16,7 @@ package org.datavec.hadoop.records.writer; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.datavec.api.records.converter.RecordReaderConverter; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml index 6118d3ecf..f286eeb95 100644 --- a/datavec/datavec-local/pom.xml +++ b/datavec/datavec-local/pom.xml @@ -50,12 +50,6 @@ protonpack ${protonpack.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.datavec datavec-api diff --git a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java index 5360eae7e..276d79f88 100644 --- a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java +++ b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java @@ -16,7 +16,7 @@ package org.datavec.local.transforms.join; -import com.google.common.collect.Iterables; +import org.nd4j.shade.guava.collect.Iterables; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; diff 
--git a/datavec/datavec-perf/pom.xml b/datavec/datavec-perf/pom.xml index 59aa0be79..fb4eaaa89 100644 --- a/datavec/datavec-perf/pom.xml +++ b/datavec/datavec-perf/pom.xml @@ -52,12 +52,6 @@ ${project.version} test - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.datavec datavec-api diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 0b147bc7a..e60bc9219 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -26,10 +26,8 @@ datavec-python - - - + com.googlecode.json-simple json-simple 1.1 @@ -39,11 +37,6 @@ cpython-platform ${cpython-platform.version} - - org.nd4j - nd4j-native - ${project.version} - com.google.code.findbugs jsr305 @@ -54,14 +47,19 @@ datavec-api ${project.version} - ch.qos.logback logback-classic ${logback.version} test + + org.nd4j + nd4j-native-api + ${project.version} + + test-nd4j-native @@ -70,5 +68,4 @@ test-nd4j-cuda-10.1 - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index d27da85df..db110703b 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -30,12 +30,6 @@ - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.nd4j nd4j-jackson diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml index 7483b1b3d..ade8acd14 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml @@ -35,12 +35,6 @@ datavec-api ${datavec.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.datavec datavec-data-image diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 2cc01e288..605b13b70 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -64,12 +64,6 @@ ${datavec.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - - joda-time joda-time @@ -106,40 +100,10 @@ ${snakeyaml.version} - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${jackson.version} - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${jackson.version} - - com.typesafe.play play-java_2.11 - ${play.version} + ${playframework.version} com.google.code.findbugs @@ -161,25 +125,31 @@ com.typesafe.play play-json_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play-server_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play_2.11 - ${play.version} + ${playframework.version} com.typesafe.play play-netty-server_2.11 - ${play.version} + ${playframework.version} + + + + com.typesafe.akka + akka-cluster_2.11 + 2.5.23 @@ -195,14 +165,11 @@ ${jcommander.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test + org.apache.spark + spark-core_2.11 + ${spark.version} - diff --git 
a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java index 893ce4218..f20799905 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/CSVSparkTransformServer.java @@ -24,12 +24,16 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.transform.TransformProcess; import org.datavec.image.transform.ImageTransformProcess; import org.datavec.spark.transform.model.*; +import play.BuiltInComponents; import play.Mode; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; import java.io.File; import java.io.IOException; +import java.util.Base64; +import java.util.Random; import static play.mvc.Results.*; @@ -66,9 +70,6 @@ public class CSVSparkTransformServer extends SparkTransformServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - - if (jsonPath != null) { String json = FileUtils.readFileToString(new File(jsonPath)); TransformProcess transformProcess = TransformProcess.fromJson(json); @@ -78,8 +79,26 @@ public class CSVSparkTransformServer extends SparkTransformServer { + "to /transformprocess"); } + //Set play secret key, if required + //http://www.playframework.com/documentation/latest/ApplicationSecret + String crypto = System.getProperty("play.crypto.secret"); + if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) { + byte[] newCrypto = new byte[1024]; - routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { + new Random().nextBytes(newCrypto); + + String base64 = Base64.getEncoder().encodeToString(newCrypto); + System.setProperty("play.crypto.secret", base64); + } + + + server = Server.forRouter(Mode.PROD, port, this::createRouter); + } + + protected Router createRouter(BuiltInComponents b){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(b); + + routingDsl.GET("/transformprocess").routingTo(req -> { try { if (transform == null) return badRequest(); @@ -88,11 +107,11 @@ public class CSVSparkTransformServer extends SparkTransformServer { log.error("Error in GET /transformprocess",e); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformprocess").routingTo(req -> { try { - TransformProcess transformProcess = TransformProcess.fromJson(getJsonText()); + TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); setCSVTransformProcess(transformProcess); log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); @@ -100,12 +119,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { log.error("Error in POST /transformprocess",e); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/transformincremental").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformincremental").routingTo(req -> { + if (isSequence(req)) { try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord record = objectMapper.readValue(getJsonText(req), 
BatchCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); @@ -115,7 +134,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); + SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); @@ -124,12 +143,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } - }))); + }); - routingDsl.POST("/transform").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transform").routingTo(req -> { + if (isSequence(req)) { try { - SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class)); + SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); if (batch == null) return badRequest(); return ok(objectMapper.writeValueAsString(batch)).as(contentType); @@ -139,7 +158,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - BatchCSVRecord input = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); BatchCSVRecord batch = transform(input); if (batch == null) return badRequest(); @@ -149,14 +168,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } + }); - - }))); - - routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformincrementalarray").routingTo(req -> { + if (isSequence(req)) { try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); @@ -166,7 +183,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(), SingleCSVRecord.class); + SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); @@ -175,13 +192,12 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } + }); - }))); - - routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> { - if (isSequence()) { + routingDsl.POST("/transformarray").routingTo(req -> { + if (isSequence(req)) { try { - SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), SequenceBatchCSVRecord.class); + SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); if (batchCSVRecord == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); @@ -191,7 +207,7 @@ public class CSVSparkTransformServer extends SparkTransformServer { } } else { try { - 
BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(), BatchCSVRecord.class); + BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); if (batchCSVRecord == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); @@ -200,10 +216,9 @@ public class CSVSparkTransformServer extends SparkTransformServer { return internalServerError(e.getMessage()); } } - }))); + }); - - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); + return routingDsl.build(); } public static void main(String[] args) throws Exception { diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java deleted file mode 100644 index 6c4874b02..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/FunctionUtil.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
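The CSVSparkTransformServer hunk above, together with the deletion of FunctionUtil below it, is the core of the Play migration: instead of wrapping suppliers in F.Function0 and calling routeTo, a Router is now built from BuiltInComponents and handed to Server.forRouter, and each handler receives the Http.Request as a lambda argument. A minimal, hypothetical sketch of that pattern outside DataVec (class name and route are illustrative):

import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import static play.mvc.Results.ok;

public class RoutingSketch {
    public static void main(String[] args) {
        // The router factory is a Function<BuiltInComponents, Router>, as in the server classes above
        Server server = Server.forRouter(Mode.PROD, 9000, RoutingSketch::createRouter);
        // server.stop() when shutting down
    }

    static Router createRouter(BuiltInComponents components) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(components);
        // routingTo passes the request explicitly; no more static request()
        routingDsl.GET("/ping").routingTo(req -> ok("pong"));
        return routingDsl.build();
    }
}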
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import play.libs.F; -import play.mvc.Result; - -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; - } - - public static F.Function function(Function function) { - return function::apply; - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java index 29c1e1bd7..f8675f139 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java @@ -24,8 +24,11 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.transform.TransformProcess; import org.datavec.image.transform.ImageTransformProcess; import org.datavec.spark.transform.model.*; +import play.BuiltInComponents; import play.Mode; +import play.libs.Files; import play.mvc.Http; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -33,6 +36,7 @@ import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import static play.mvc.Controller.request; import static play.mvc.Results.*; @@ -62,8 +66,6 @@ public class ImageSparkTransformServer extends SparkTransformServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - if (jsonPath != null) { String json = FileUtils.readFileToString(new File(jsonPath)); ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); @@ -73,7 +75,13 @@ public class ImageSparkTransformServer extends SparkTransformServer { + "to /transformprocess"); } - routingDsl.GET("/transformprocess").routeTo(FunctionUtil.function0((() -> { + server = Server.forRouter(Mode.PROD, port, this::createRouter); + } + + protected Router createRouter(BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); + + routingDsl.GET("/transformprocess").routingTo(req -> { try { if (transform == null) return badRequest(); @@ -83,11 +91,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformprocess").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformprocess").routingTo(req -> { try { - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText()); + ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); setImageTransformProcess(transformProcess); log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); @@ -95,11 +103,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - 
routingDsl.POST("/transformincrementalarray").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformincrementalarray").routingTo(req -> { try { - SingleImageRecord record = objectMapper.readValue(getJsonText(), SingleImageRecord.class); + SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); if (record == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); @@ -107,17 +115,17 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformincrementalimage").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformincrementalimage").routingTo(req -> { try { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List files = body.getFiles(); - if (files.size() == 0 || files.get(0).getFile() == null) { + Http.MultipartFormData body = req.body().asMultipartFormData(); + List> files = body.getFiles(); + if (files.isEmpty() || files.get(0).getRef() == null ) { return badRequest(); } - File file = files.get(0).getFile(); + File file = files.get(0).getRef().path().toFile(); SingleImageRecord record = new SingleImageRecord(file.toURI()); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); @@ -125,11 +133,11 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformarray").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformarray").routingTo(req -> { try { - BatchImageRecord batch = objectMapper.readValue(getJsonText(), BatchImageRecord.class); + BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); if (batch == null) return badRequest(); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); @@ -137,22 +145,22 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - routingDsl.POST("/transformimage").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/transformimage").routingTo(req -> { try { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List files = body.getFiles(); + Http.MultipartFormData body = req.body().asMultipartFormData(); + List> files = body.getFiles(); if (files.size() == 0) { return badRequest(); } List records = new ArrayList<>(); - for (Http.MultipartFormData.FilePart filePart : files) { - File file = filePart.getFile(); + for (Http.MultipartFormData.FilePart filePart : files) { + Files.TemporaryFile file = filePart.getRef(); if (file != null) { - SingleImageRecord record = new SingleImageRecord(file.toURI()); + SingleImageRecord record = new SingleImageRecord(file.path().toUri()); records.add(record); } } @@ -164,9 +172,9 @@ public class ImageSparkTransformServer extends SparkTransformServer { e.printStackTrace(); return internalServerError(); } - }))); + }); - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); + return routingDsl.build(); } @Override diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java index 2d4c92836..411872006 
100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/SparkTransformServer.java @@ -22,6 +22,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.service.DataVecTransformService; import org.nd4j.shade.jackson.databind.ObjectMapper; +import play.mvc.Http; import play.server.Server; import static play.mvc.Controller.request; @@ -50,25 +51,17 @@ public abstract class SparkTransformServer implements DataVecTransformService { server.stop(); } - protected boolean isSequence() { - return request().hasHeader(SEQUENCE_OR_NOT_HEADER) - && request().getHeader(SEQUENCE_OR_NOT_HEADER).toUpperCase() - .equals("TRUE"); + protected boolean isSequence(Http.Request request) { + return request.hasHeader(SEQUENCE_OR_NOT_HEADER) + && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); } - - protected String getHeaderValue(String value) { - if (request().hasHeader(value)) - return request().getHeader(value); - return null; - } - - protected String getJsonText() { - JsonNode tryJson = request().body().asJson(); + protected String getJsonText(Http.Request request) { + JsonNode tryJson = request.body().asJson(); if (tryJson != null) return tryJson.toString(); else - return request().body().asText(); + return request.body().asText(); } public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf new file mode 100644 index 000000000..28a4aa208 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf @@ -0,0 +1,350 @@ +# This is the main configuration file for the application. +# https://www.playframework.com/documentation/latest/ConfigFile +# ~~~~~ +# Play uses HOCON as its configuration file format. HOCON has a number +# of advantages over other config formats, but there are two things that +# can be used when modifying settings. +# +# You can include other configuration files in this main application.conf file: +#include "extra-config.conf" +# +# You can declare variables and substitute for them: +#mykey = ${some.value} +# +# And if an environment variable exists when there is no other subsitution, then +# HOCON will fall back to substituting environment variable: +#mykey = ${JAVA_HOME} + +## Akka +# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration +# https://www.playframework.com/documentation/latest/JavaAkka#Configuration +# ~~~~~ +# Play uses Akka internally and exposes Akka Streams and actors in Websockets and +# other streaming HTTP responses. +akka { + # "akka.log-config-on-start" is extraordinarly useful because it log the complete + # configuration at INFO level, including defaults and overrides, so it s worth + # putting at the very top. + # + # Put the following in your conf/logback.xml file: + # + # + # + # And then uncomment this line to debug the configuration. 
+ # + #log-config-on-start = true +} + +## Modules +# https://www.playframework.com/documentation/latest/Modules +# ~~~~~ +# Control which modules are loaded when Play starts. Note that modules are +# the replacement for "GlobalSettings", which are deprecated in 2.5.x. +# Please see https://www.playframework.com/documentation/latest/GlobalSettings +# for more information. +# +# You can also extend Play functionality by using one of the publically available +# Play modules: https://playframework.com/documentation/latest/ModuleDirectory +play.modules { + # By default, Play will load any class called Module that is defined + # in the root package (the "app" directory), or you can define them + # explicitly below. + # If there are any built-in modules that you want to disable, you can list them here. + #enabled += my.application.Module + + # If there are any built-in modules that you want to disable, you can list them here. + #disabled += "" +} + +## Internationalisation +# https://www.playframework.com/documentation/latest/JavaI18N +# https://www.playframework.com/documentation/latest/ScalaI18N +# ~~~~~ +# Play comes with its own i18n settings, which allow the user's preferred language +# to map through to internal messages, or allow the language to be stored in a cookie. +play.i18n { + # The application languages + langs = [ "en" ] + + # Whether the language cookie should be secure or not + #langCookieSecure = true + + # Whether the HTTP only attribute of the cookie should be set to true + #langCookieHttpOnly = true +} + +## Play HTTP settings +# ~~~~~ +play.http { + ## Router + # https://www.playframework.com/documentation/latest/JavaRouting + # https://www.playframework.com/documentation/latest/ScalaRouting + # ~~~~~ + # Define the Router object to use for this application. + # This router will be looked up first when the application is starting up, + # so make sure this is the entry point. + # Furthermore, it's assumed your route file is named properly. + # So for an application router like `my.application.Router`, + # you may need to define a router file `conf/my.application.routes`. + # Default to Routes in the root package (aka "apps" folder) (and conf/routes) + #router = my.application.Router + + ## Action Creator + # https://www.playframework.com/documentation/latest/JavaActionCreator + # ~~~~~ + #actionCreator = null + + ## ErrorHandler + # https://www.playframework.com/documentation/latest/JavaRouting + # https://www.playframework.com/documentation/latest/ScalaRouting + # ~~~~~ + # If null, will attempt to load a class called ErrorHandler in the root package, + #errorHandler = null + + ## Filters + # https://www.playframework.com/documentation/latest/ScalaHttpFilters + # https://www.playframework.com/documentation/latest/JavaHttpFilters + # ~~~~~ + # Filters run code on every request. They can be used to perform + # common logic for all your actions, e.g. adding common headers. + # Defaults to "Filters" in the root package (aka "apps" folder) + # Alternatively you can explicitly register a class here. + #filters += my.application.Filters + + ## Session & Flash + # https://www.playframework.com/documentation/latest/JavaSessionFlash + # https://www.playframework.com/documentation/latest/ScalaSessionFlash + # ~~~~~ + session { + # Sets the cookie to be sent only over HTTPS. + #secure = true + + # Sets the cookie to be accessed only by the server. + #httpOnly = true + + # Sets the max-age field of the cookie to 5 minutes. 
+ # NOTE: this only sets when the browser will discard the cookie. Play will consider any + # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, + # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. + #maxAge = 300 + + # Sets the domain on the session cookie. + #domain = "example.com" + } + + flash { + # Sets the cookie to be sent only over HTTPS. + #secure = true + + # Sets the cookie to be accessed only by the server. + #httpOnly = true + } +} + +## Netty Provider +# https://www.playframework.com/documentation/latest/SettingsNetty +# ~~~~~ +play.server.netty { + # Whether the Netty wire should be logged + #log.wire = true + + # If you run Play on Linux, you can use Netty's native socket transport + # for higher performance with less garbage. + #transport = "native" +} + +## WS (HTTP Client) +# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS +# ~~~~~ +# The HTTP client primarily used for REST APIs. The default client can be +# configured directly, but you can also create different client instances +# with customized settings. You must enable this by adding to build.sbt: +# +# libraryDependencies += ws // or javaWs if using java +# +play.ws { + # Sets HTTP requests not to follow 302 requests + #followRedirects = false + + # Sets the maximum number of open HTTP connections for the client. + #ahc.maxConnectionsTotal = 50 + + ## WS SSL + # https://www.playframework.com/documentation/latest/WsSSL + # ~~~~~ + ssl { + # Configuring HTTPS with Play WS does not require programming. You can + # set up both trustManager and keyManager for mutual authentication, and + # turn on JSSE debugging in development with a reload. + #debug.handshake = true + #trustManager = { + # stores = [ + # { type = "JKS", path = "exampletrust.jks" } + # ] + #} + } +} + +## Cache +# https://www.playframework.com/documentation/latest/JavaCache +# https://www.playframework.com/documentation/latest/ScalaCache +# ~~~~~ +# Play comes with an integrated cache API that can reduce the operational +# overhead of repeated requests. You must enable this by adding to build.sbt: +# +# libraryDependencies += cache +# +play.cache { + # If you want to bind several caches, you can bind the individually + #bindCaches = ["db-cache", "user-cache", "session-cache"] +} + +## Filters +# https://www.playframework.com/documentation/latest/Filters +# ~~~~~ +# There are a number of built-in filters that can be enabled and configured +# to give Play greater security. You must enable this by adding to build.sbt: +# +# libraryDependencies += filters +# +play.filters { + ## CORS filter configuration + # https://www.playframework.com/documentation/latest/CorsFilter + # ~~~~~ + # CORS is a protocol that allows web applications to make requests from the browser + # across different domains. + # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has + # dependencies on CORS settings. + cors { + # Filter paths by a whitelist of path prefixes + #pathPrefixes = ["/some/path", ...] + + # The allowed origins. If null, all origins are allowed. + #allowedOrigins = ["http://www.example.com"] + + # The allowed HTTP methods. 
If null, all methods are allowed + #allowedHttpMethods = ["GET", "POST"] + } + + ## CSRF Filter + # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter + # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter + # ~~~~~ + # Play supports multiple methods for verifying that a request is not a CSRF request. + # The primary mechanism is a CSRF token. This token gets placed either in the query string + # or body of every form submitted, and also gets placed in the users session. + # Play then verifies that both tokens are present and match. + csrf { + # Sets the cookie to be sent only over HTTPS + #cookie.secure = true + + # Defaults to CSRFErrorHandler in the root package. + #errorHandler = MyCSRFErrorHandler + } + + ## Security headers filter configuration + # https://www.playframework.com/documentation/latest/SecurityHeaders + # ~~~~~ + # Defines security headers that prevent XSS attacks. + # If enabled, then all options are set to the below configuration by default: + headers { + # The X-Frame-Options header. If null, the header is not set. + #frameOptions = "DENY" + + # The X-XSS-Protection header. If null, the header is not set. + #xssProtection = "1; mode=block" + + # The X-Content-Type-Options header. If null, the header is not set. + #contentTypeOptions = "nosniff" + + # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. + #permittedCrossDomainPolicies = "master-only" + + # The Content-Security-Policy header. If null, the header is not set. + #contentSecurityPolicy = "default-src 'self'" + } + + ## Allowed hosts filter configuration + # https://www.playframework.com/documentation/latest/AllowedHostsFilter + # ~~~~~ + # Play provides a filter that lets you configure which hosts can access your application. + # This is useful to prevent cache poisoning attacks. + hosts { + # Allow requests to example.com, its subdomains, and localhost:9000. + #allowed = [".example.com", "localhost:9000"] + } +} + +## Evolutions +# https://www.playframework.com/documentation/latest/Evolutions +# ~~~~~ +# Evolutions allows database scripts to be automatically run on startup in dev mode +# for database migrations. You must enable this by adding to build.sbt: +# +# libraryDependencies += evolutions +# +play.evolutions { + # You can disable evolutions for a specific datasource if necessary + #db.default.enabled = false +} + +## Database Connection Pool +# https://www.playframework.com/documentation/latest/SettingsJDBC +# ~~~~~ +# Play doesn't require a JDBC database to run, but you can easily enable one. +# +# libraryDependencies += jdbc +# +play.db { + # The combination of these two settings results in "db.default" as the + # default JDBC pool: + #config = "db" + #default = "default" + + # Play uses HikariCP as the default connection pool. 
You can override + # settings by changing the prototype: + prototype { + # Sets a fixed JDBC connection pool size of 50 + #hikaricp.minimumIdle = 50 + #hikaricp.maximumPoolSize = 50 + } +} + +## JDBC Datasource +# https://www.playframework.com/documentation/latest/JavaDatabase +# https://www.playframework.com/documentation/latest/ScalaDatabase +# ~~~~~ +# Once JDBC datasource is set up, you can work with several different +# database options: +# +# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick +# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA +# EBean: https://playframework.com/documentation/latest/JavaEbean +# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm +# +db { + # You can declare as many datasources as you want. + # By convention, the default datasource is named `default` + + # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database + default.driver = org.h2.Driver + default.url = "jdbc:h2:mem:play" + #default.username = sa + #default.password = "" + + # You can expose this datasource via JNDI if needed (Useful for JPA) + default.jndiName=DefaultDS + + # You can turn on SQL logging for any datasource + # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements + #default.logSql=true +} + +jpa.default=defaultPersistenceUnit + + +#Increase default maximum post length - used for remote listener functionality +#Can get response 413 with larger networks without setting this +# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead +#parsers.text.maxLength=10M +play.http.parser.maxMemoryBuffer=10M diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 2f8a046ec..05c505cac 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -28,61 +28,11 @@ datavec-spark_2.11 - - 2.1.0 - 2 - 2.11.12 2.11 - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-source - generate-sources - - add-source - - - - src/main/spark-${spark.major.version} - - - - - - - - - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_2.11 - ${jackson.version} - - - - org.scala-lang @@ -95,42 +45,13 @@ ${scala.version} - - org.codehaus.jackson - jackson-core-asl - ${jackson-asl.version} - - - org.codehaus.jackson - jackson-mapper-asl - ${jackson-asl.version} - org.apache.spark spark-sql_2.11 ${spark.version} + provided - - com.google.guava - guava - ${guava.version} - - - com.google.inject - guice - ${guice.version} - - - com.google.protobuf - protobuf-java - ${google.protobuf.version} - - - commons-codec - commons-codec - ${commons-codec.version} - commons-collections commons-collections @@ -141,96 +62,16 @@ commons-io ${commons-io.version} - - commons-lang - commons-lang - ${commons-lang.version} - - - commons-net - commons-net - ${commons-net.version} - - - com.sun.xml.bind - jaxb-core - ${jaxb.version} - - - com.sun.xml.bind - jaxb-impl - ${jaxb.version} - - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - - - io.netty - netty - ${netty.version} - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - 
${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - javax.servlet - javax.servlet-api - ${servlet.version} - - - org.apache.commons - commons-compress - ${commons-compress.version} - - - org.apache.commons - commons-lang3 - ${commons-lang3.version} - org.apache.commons commons-math3 ${commons-math3.version} - - org.apache.curator - curator-recipes - ${curator.version} - org.slf4j slf4j-api ${slf4j.version} - - com.typesafe - config - ${typesafe.config.version} - org.apache.spark spark-core_2.11 @@ -241,14 +82,6 @@ com.google.code.findbugs jsr305 - - org.slf4j - slf4j-log4j12 - - - log4j - log4j - @@ -281,13 +114,6 @@ test - - org.nd4j - nd4j-native - ${nd4j.version} - test - - org.datavec datavec-local diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java index fbe8a63d4..5d0bff7f3 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/DataFrames.java @@ -21,10 +21,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; -import org.apache.spark.sql.Column; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.functions; +import org.apache.spark.sql.*; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -46,7 +43,6 @@ import java.util.List; import static org.apache.spark.sql.functions.avg; import static org.apache.spark.sql.functions.col; -import static org.datavec.spark.transform.DataRowsFacade.dataRows; /** @@ -71,7 +67,7 @@ public class DataFrames { * deviation for * @return the column that represents the standard deviation */ - public static Column std(DataRowsFacade dataFrame, String columnName) { + public static Column std(Dataset dataFrame, String columnName) { return functions.sqrt(var(dataFrame, columnName)); } @@ -85,8 +81,8 @@ public class DataFrames { * deviation for * @return the column that represents the standard deviation */ - public static Column var(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.variance(columnName)).col(columnName); + public static Column var(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(functions.variance(columnName)).col(columnName); } /** @@ -97,8 +93,8 @@ public class DataFrames { * @param columnName the name of the column to get the min for * @return the column that represents the min */ - public static Column min(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.min(columnName)).col(columnName); + public static Column min(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(functions.min(columnName)).col(columnName); } /** @@ -110,8 +106,8 @@ public class DataFrames { * to get the max for * @return the column that represents the max */ - public static Column max(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(functions.max(columnName)).col(columnName); + public static Column max(Dataset dataFrame, String columnName) { + return 
dataFrame.groupBy(columnName).agg(functions.max(columnName)).col(columnName); } /** @@ -122,8 +118,8 @@ public class DataFrames { * @param columnName the name of the column to get the mean for * @return the column that represents the mean */ - public static Column mean(DataRowsFacade dataFrame, String columnName) { - return dataFrame.get().groupBy(columnName).agg(avg(columnName)).col(columnName); + public static Column mean(Dataset dataFrame, String columnName) { + return dataFrame.groupBy(columnName).agg(avg(columnName)).col(columnName); } /** @@ -166,7 +162,7 @@ public class DataFrames { * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position * of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example - * using {@link #toRecordsSequence(DataRowsFacade)} + * using {@link #toRecordsSequence(Dataset)} * * @param schema Schema to convert * @return StructType for the schema @@ -250,9 +246,9 @@ public class DataFrames { * @param dataFrame the dataframe to convert * @return the converted schema and rdd of writables */ - public static Pair>> toRecords(DataRowsFacade dataFrame) { - Schema schema = fromStructType(dataFrame.get().schema()); - return new Pair<>(schema, dataFrame.get().javaRDD().map(new ToRecord(schema))); + public static Pair>> toRecords(Dataset dataFrame) { + Schema schema = fromStructType(dataFrame.schema()); + return new Pair<>(schema, dataFrame.javaRDD().map(new ToRecord(schema))); } /** @@ -267,11 +263,11 @@ public class DataFrames { * @param dataFrame Data frame to convert * @return Data in sequence (i.e., {@code List>} form */ - public static Pair>>> toRecordsSequence(DataRowsFacade dataFrame) { + public static Pair>>> toRecordsSequence(Dataset dataFrame) { //Need to convert from flattened to sequence data... //First: Group by the Sequence UUID (first column) - JavaPairRDD> grouped = dataFrame.get().javaRDD().groupBy(new Function() { + JavaPairRDD> grouped = dataFrame.javaRDD().groupBy(new Function() { @Override public String call(Row row) throws Exception { return row.getString(0); @@ -279,7 +275,7 @@ public class DataFrames { }); - Schema schema = fromStructType(dataFrame.get().schema()); + Schema schema = fromStructType(dataFrame.schema()); //Group by sequence UUID, and sort each row within the sequences using the time step index Function, List>> createCombiner = new DataFrameToSequenceCreateCombiner(schema); //Function to create the initial combiner @@ -318,11 +314,11 @@ public class DataFrames { * @param data the data to convert * @return the dataframe object */ - public static DataRowsFacade toDataFrame(Schema schema, JavaRDD> data) { + public static Dataset toDataFrame(Schema schema, JavaRDD> data) { JavaSparkContext sc = new JavaSparkContext(data.context()); SQLContext sqlContext = new SQLContext(sc); JavaRDD rows = data.map(new ToRow(schema)); - return dataRows(sqlContext.createDataFrame(rows, fromSchema(schema))); + return sqlContext.createDataFrame(rows, fromSchema(schema)); } @@ -333,18 +329,18 @@ public class DataFrames { * - Column 1: Sequence index (name: {@link #SEQUENCE_INDEX_COLUMN} - an index (integer, starting at 0) for the position * of this record in the original time series.
* These two columns are required if the data is to be converted back into a sequence at a later point, for example - * using {@link #toRecordsSequence(DataRowsFacade)} + * using {@link #toRecordsSequence(Dataset)} * * @param schema Schema for the data * @param data Sequence data to convert to a DataFrame * @return The dataframe object */ - public static DataRowsFacade toDataFrameSequence(Schema schema, JavaRDD>> data) { + public static Dataset toDataFrameSequence(Schema schema, JavaRDD>> data) { JavaSparkContext sc = new JavaSparkContext(data.context()); SQLContext sqlContext = new SQLContext(sc); JavaRDD rows = data.flatMap(new SequenceToRows(schema)); - return dataRows(sqlContext.createDataFrame(rows, fromSchemaSequence(schema))); + return sqlContext.createDataFrame(rows, fromSchemaSequence(schema)); } /** diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java index 68efd1888..cacea101d 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/Normalization.java @@ -19,14 +19,13 @@ package org.datavec.spark.transform; import org.apache.commons.collections.map.ListOrderedMap; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Writable; import java.util.*; -import static org.datavec.spark.transform.DataRowsFacade.dataRows; - /** * Simple dataframe based normalization. @@ -46,7 +45,7 @@ public class Normalization { * @return a zero mean unit variance centered * rdd */ - public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame) { + public static Dataset zeromeanUnitVariance(Dataset frame) { return zeromeanUnitVariance(frame, Collections.emptyList()); } @@ -71,7 +70,7 @@ public class Normalization { * @param max the maximum value * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max) { + public static Dataset normalize(Dataset dataFrame, double min, double max) { return normalize(dataFrame, min, max, Collections.emptyList()); } @@ -86,7 +85,7 @@ public class Normalization { */ public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(normalize(frame, min, max, Collections.emptyList())).getSecond(); } @@ -97,7 +96,7 @@ public class Normalization { * @param dataFrame the dataframe to scale * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame) { + public static Dataset normalize(Dataset dataFrame) { return normalize(dataFrame, 0, 1, Collections.emptyList()); } @@ -120,8 +119,8 @@ public class Normalization { * @return a zero mean unit variance centered * rdd */ - public static DataRowsFacade zeromeanUnitVariance(DataRowsFacade frame, List skipColumns) { - List columnsList = DataFrames.toList(frame.get().columns()); + public static Dataset zeromeanUnitVariance(Dataset frame, List skipColumns) { + List columnsList = DataFrames.toList(frame.columns()); columnsList.removeAll(skipColumns); String[] columnNames = DataFrames.toArray(columnsList); //first row is std 
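With DataRowsFacade removed, DataFrames above converts directly between writable records and Spark's Dataset<Row>. A hedged round-trip sketch using the methods changed in this hunk; the schema, column names and values are illustrative:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.DataFrames;

import java.util.Arrays;
import java.util.List;

public class DataFramesRoundTrip {
    public static void sketch(JavaSparkContext sc) {
        Schema schema = new Schema.Builder().addColumnDouble("x").addColumnDouble("y").build();
        JavaRDD<List<Writable>> data = sc.parallelize(Arrays.asList(
                Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)),
                Arrays.<Writable>asList(new DoubleWritable(3.0), new DoubleWritable(4.0))));

        // Writables -> Dataset<Row>, no facade wrapper involved
        Dataset<Row> df = DataFrames.toDataFrame(schema, data);

        // ...and back: the pair's second element is the record RDD
        JavaRDD<List<Writable>> back = DataFrames.toRecords(df).getSecond();
    }
}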
second row is mean, each column in a row is for a particular column @@ -133,7 +132,7 @@ public class Normalization { if (std == 0.0) std = 1; //All same value -> (x-x)/1 = 0 - frame = dataRows(frame.get().withColumn(columnName, frame.get().col(columnName).minus(mean).divide(std))); + frame = frame.withColumn(columnName, frame.col(columnName).minus(mean).divide(std)); } @@ -152,7 +151,7 @@ public class Normalization { */ public static JavaRDD> zeromeanUnitVariance(Schema schema, JavaRDD> data, List skipColumns) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(zeromeanUnitVariance(frame, skipColumns)).getSecond(); } @@ -178,7 +177,7 @@ public class Normalization { */ public static JavaRDD>> zeroMeanUnitVarianceSequence(Schema schema, JavaRDD>> sequence, List excludeColumns) { - DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, sequence); + Dataset frame = DataFrames.toDataFrameSequence(schema, sequence); if (excludeColumns == null) excludeColumns = Arrays.asList(DataFrames.SEQUENCE_UUID_COLUMN, DataFrames.SEQUENCE_INDEX_COLUMN); else { @@ -196,7 +195,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List minMaxColumns(DataRowsFacade data, List columns) { + public static List minMaxColumns(Dataset data, List columns) { String[] arr = new String[columns.size()]; for (int i = 0; i < arr.length; i++) arr[i] = columns.get(i); @@ -210,7 +209,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List minMaxColumns(DataRowsFacade data, String... columns) { + public static List minMaxColumns(Dataset data, String... columns) { return aggregate(data, columns, new String[] {"min", "max"}); } @@ -221,7 +220,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List stdDevMeanColumns(DataRowsFacade data, List columns) { + public static List stdDevMeanColumns(Dataset data, List columns) { String[] arr = new String[columns.size()]; for (int i = 0; i < arr.length; i++) arr[i] = columns.get(i); @@ -237,7 +236,7 @@ public class Normalization { * @param columns the columns to get the * @return */ - public static List stdDevMeanColumns(DataRowsFacade data, String... columns) { + public static List stdDevMeanColumns(Dataset data, String... columns) { return aggregate(data, columns, new String[] {"stddev", "mean"}); } @@ -251,7 +250,7 @@ public class Normalization { * Each row will be a function with the desired columnar output * in the order in which the columns were specified. 
*/ - public static List aggregate(DataRowsFacade data, String[] columns, String[] functions) { + public static List aggregate(Dataset data, String[] columns, String[] functions) { String[] rest = new String[columns.length - 1]; System.arraycopy(columns, 1, rest, 0, rest.length); List rows = new ArrayList<>(); @@ -262,8 +261,8 @@ public class Normalization { } //compute the aggregation based on the operation - DataRowsFacade aggregated = dataRows(data.get().agg(expressions)); - String[] columns2 = aggregated.get().columns(); + Dataset aggregated = data.agg(expressions); + String[] columns2 = aggregated.columns(); //strip out the op name and parentheses from the columns Map opReplace = new TreeMap<>(); for (String s : columns2) { @@ -278,20 +277,20 @@ public class Normalization { //get rid of the operation name in the column - DataRowsFacade rearranged = null; + Dataset rearranged = null; for (Map.Entry entries : opReplace.entrySet()) { //first column if (rearranged == null) { - rearranged = dataRows(aggregated.get().withColumnRenamed(entries.getKey(), entries.getValue())); + rearranged = aggregated.withColumnRenamed(entries.getKey(), entries.getValue()); } //rearranged is just a copy of aggregated at this point else - rearranged = dataRows(rearranged.get().withColumnRenamed(entries.getKey(), entries.getValue())); + rearranged = rearranged.withColumnRenamed(entries.getKey(), entries.getValue()); } - rearranged = dataRows(rearranged.get().select(DataFrames.toColumns(columns))); + rearranged = rearranged.select(DataFrames.toColumns(columns)); //op - rows.addAll(rearranged.get().collectAsList()); + rows.addAll(rearranged.collectAsList()); } @@ -307,8 +306,8 @@ public class Normalization { * @param max the maximum value * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, double min, double max, List skipColumns) { - List columnsList = DataFrames.toList(dataFrame.get().columns()); + public static Dataset normalize(Dataset dataFrame, double min, double max, List skipColumns) { + List columnsList = DataFrames.toList(dataFrame.columns()); columnsList.removeAll(skipColumns); String[] columnNames = DataFrames.toArray(columnsList); //first row is min second row is max, each column in a row is for a particular column @@ -321,8 +320,8 @@ public class Normalization { if (maxSubMin == 0) maxSubMin = 1; - Column newCol = dataFrame.get().col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min); - dataFrame = dataRows(dataFrame.get().withColumn(columnName, newCol)); + Column newCol = dataFrame.col(columnName).minus(dMin).divide(maxSubMin).multiply(max - min).plus(min); + dataFrame = dataFrame.withColumn(columnName, newCol); } @@ -340,7 +339,7 @@ public class Normalization { */ public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max, List skipColumns) { - DataRowsFacade frame = DataFrames.toDataFrame(schema, data); + Dataset frame = DataFrames.toDataFrame(schema, data); return DataFrames.toRecords(normalize(frame, min, max, skipColumns)).getSecond(); } @@ -387,7 +386,7 @@ public class Normalization { excludeColumns.add(DataFrames.SEQUENCE_UUID_COLUMN); excludeColumns.add(DataFrames.SEQUENCE_INDEX_COLUMN); } - DataRowsFacade frame = DataFrames.toDataFrameSequence(schema, data); + Dataset frame = DataFrames.toDataFrameSequence(schema, data); return DataFrames.toRecordsSequence(normalize(frame, min, max, excludeColumns)).getSecond(); } @@ -398,7 +397,7 @@ public class Normalization { * @param dataFrame 
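The Normalization helpers above follow the same substitution: everything that previously took a DataRowsFacade now takes a Dataset<Row>, while the RDD-based entry points keep their signatures for callers. A small sketch of min-max scaling through one of those entry points (the method signature is shown in the hunk above; column contents are whatever the schema describes):

import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.Normalization;

import java.util.List;

public class NormalizationSketch {
    // Scales every (non-excluded) column into [0, 1]; internally goes through DataFrames.toDataFrame
    public static JavaRDD<List<Writable>> scaleToUnitRange(Schema schema, JavaRDD<List<Writable>> data) {
        return Normalization.normalize(schema, data, 0, 1);
    }
}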
the dataframe to scale * @return the normalized dataframe per column */ - public static DataRowsFacade normalize(DataRowsFacade dataFrame, List skipColumns) { + public static Dataset normalize(Dataset dataFrame, List skipColumns) { return normalize(dataFrame, 0, 1, skipColumns); } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java index 6b7ff203f..5052491bb 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java @@ -16,9 +16,10 @@ package org.datavec.spark.transform.analysis; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Iterator; import java.util.List; /** @@ -27,10 +28,11 @@ import java.util.List; * * @author Alex Black */ -public class SequenceFlatMapFunction extends BaseFlatMapFunctionAdaptee>, List> { +public class SequenceFlatMapFunction implements FlatMapFunction>, List> { - public SequenceFlatMapFunction() { - super(new SequenceFlatMapFunctionAdapter()); + @Override + public Iterator> call(List> collections) throws Exception { + return collections.iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java deleted file mode 100644 index 6b25fb826..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunctionAdapter.java +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
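The SequenceFlatMapFunction rewrite above shows the general shape of dropping the adapter layer: under Spark 2.x, FlatMapFunction.call returns an Iterator rather than an Iterable, so the class now implements org.apache.spark.api.java.function.FlatMapFunction directly. A minimal sketch of the same flattening function:

import org.apache.spark.api.java.function.FlatMapFunction;
import org.datavec.api.writable.Writable;

import java.util.Iterator;
import java.util.List;

// Flattens a sequence (a list of time steps) into its individual records
public class FlattenSequenceSketch implements FlatMapFunction<List<List<Writable>>, List<Writable>> {
    @Override
    public Iterator<List<Writable>> call(List<List<Writable>> sequence) {
        return sequence.iterator();
    }
}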
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.analysis; - -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.List; - -/** - * SequenceFlatMapFunction: very simple function used to flatten a sequence - * Typically used only internally for certain analysis operations - * - * @author Alex Black - */ -public class SequenceFlatMapFunctionAdapter implements FlatMapFunctionAdapter>, List> { - @Override - public Iterable> call(List> collections) throws Exception { - return collections; - } - -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java index 52bf924be..6e501f560 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java @@ -16,11 +16,14 @@ package org.datavec.spark.transform.join; +import org.nd4j.shade.guava.collect.Iterables; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Iterator; import java.util.List; /** @@ -28,10 +31,89 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteJoinFromCoGroupFlatMapFunction extends - BaseFlatMapFunctionAdaptee, Tuple2>, Iterable>>>, List> { +public class ExecuteJoinFromCoGroupFlatMapFunction implements FlatMapFunction, Tuple2>, Iterable>>>, List> { + + private final Join join; public ExecuteJoinFromCoGroupFlatMapFunction(Join join) { - super(new ExecuteJoinFromCoGroupFlatMapFunctionAdapter(join)); + this.join = join; + } + + @Override + public Iterator> call( + Tuple2, Tuple2>, Iterable>>> t2) + throws Exception { + + Iterable> leftList = t2._2()._1(); + Iterable> rightList = t2._2()._2(); + + List> ret = new ArrayList<>(); + Join.JoinType jt = join.getJoinType(); + switch (jt) { + case Inner: + //Return records where key columns appear in BOTH + //So if no values from left OR right: no return values + for (List jvl : leftList) { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + break; + case LeftOuter: + //Return all records from left, even if no corresponding right value (NullWritable in that case) + for (List jvl : leftList) { + if (Iterables.size(rightList) == 0) { + List joined = join.joinExamples(jvl, null); + ret.add(joined); + } else { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + case RightOuter: + //Return all records from right, even if no corresponding left value (NullWritable in that case) + for (List jvr : rightList) { + if (Iterables.size(leftList) == 0) { + List joined = join.joinExamples(null, jvr); + ret.add(joined); + } else { + for (List jvl : leftList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + case FullOuter: + //Return all records, even if no corresponding left/right value (NullWritable in that case) + if (Iterables.size(leftList) == 0) { + //Only right values + 
for (List jvr : rightList) { + List joined = join.joinExamples(null, jvr); + ret.add(joined); + } + } else if (Iterables.size(rightList) == 0) { + //Only left values + for (List jvl : leftList) { + List joined = join.joinExamples(jvl, null); + ret.add(joined); + } + } else { + //Records from both left and right + for (List jvl : leftList) { + for (List jvr : rightList) { + List joined = join.joinExamples(jvl, jvr); + ret.add(joined); + } + } + } + break; + } + + return ret.iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java deleted file mode 100644 index dedff46d0..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.join; - -import com.google.common.collect.Iterables; -import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -import java.util.ArrayList; -import java.util.List; - -/** - * Execute a join - * - * @author Alex Black - */ -public class ExecuteJoinFromCoGroupFlatMapFunctionAdapter implements - FlatMapFunctionAdapter, Tuple2>, Iterable>>>, List> { - - private final Join join; - - public ExecuteJoinFromCoGroupFlatMapFunctionAdapter(Join join) { - this.join = join; - } - - @Override - public Iterable> call( - Tuple2, Tuple2>, Iterable>>> t2) - throws Exception { - - Iterable> leftList = t2._2()._1(); - Iterable> rightList = t2._2()._2(); - - List> ret = new ArrayList<>(); - Join.JoinType jt = join.getJoinType(); - switch (jt) { - case Inner: - //Return records where key columns appear in BOTH - //So if no values from left OR right: no return values - for (List jvl : leftList) { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - break; - case LeftOuter: - //Return all records from left, even if no corresponding right value (NullWritable in that case) - for (List jvl : leftList) { - if (Iterables.size(rightList) == 0) { - List joined = join.joinExamples(jvl, null); - ret.add(joined); - } else { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - case RightOuter: - //Return all records from right, even if no corresponding left value (NullWritable in that case) - for (List jvr : rightList) { - if (Iterables.size(leftList) == 0) { - List joined = join.joinExamples(null, jvr); - ret.add(joined); - } else { - for (List jvl : 
leftList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - case FullOuter: - //Return all records, even if no corresponding left/right value (NullWritable in that case) - if (Iterables.size(leftList) == 0) { - //Only right values - for (List jvr : rightList) { - List joined = join.joinExamples(null, jvr); - ret.add(joined); - } - } else if (Iterables.size(rightList) == 0) { - //Only left values - for (List jvl : leftList) { - List joined = join.joinExamples(jvl, null); - ret.add(joined); - } - } else { - //Records from both left and right - for (List jvl : leftList) { - for (List jvr : rightList) { - List joined = join.joinExamples(jvl, jvr); - ret.add(joined); - } - } - } - break; - } - - return ret; - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java index 6e206a657..d4ede4808 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java @@ -16,10 +16,12 @@ package org.datavec.spark.transform.join; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Collections; +import java.util.Iterator; import java.util.List; /** @@ -29,10 +31,43 @@ import java.util.List; * * @author Alex Black */ -public class FilterAndFlattenJoinedValues extends BaseFlatMapFunctionAdaptee> { +public class FilterAndFlattenJoinedValues implements FlatMapFunction> { + + private final Join.JoinType joinType; public FilterAndFlattenJoinedValues(Join.JoinType joinType) { - super(new FilterAndFlattenJoinedValuesAdapter(joinType)); + this.joinType = joinType; + } + + @Override + public Iterator> call(JoinedValue joinedValue) throws Exception { + boolean keep; + switch (joinType) { + case Inner: + //Only keep joined values where we have both left and right + keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight(); + break; + case LeftOuter: + //Keep all values where left is not missing/null + keep = joinedValue.isHaveLeft(); + break; + case RightOuter: + //Keep all values where right is not missing/null + keep = joinedValue.isHaveRight(); + break; + case FullOuter: + //Keep all values + keep = true; + break; + default: + throw new RuntimeException("Unknown/not implemented join type: " + joinType); + } + + if (keep) { + return Collections.singletonList(joinedValue.getValues()).iterator(); + } else { + return Collections.emptyIterator(); + } } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java deleted file mode 100644 index 3333276b1..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValuesAdapter.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
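Note on the pattern used throughout these datavec-spark changes: the FlatMapFunctionAdapter / BaseFlatMapFunctionAdaptee indirection existed only to paper over an API difference between Spark versions. In Spark 1.x the Java FlatMapFunction.call(T) returned an Iterable, while in Spark 2.x it returns an Iterator, as the two deleted BaseFlatMapFunctionAdaptee variants further down in this diff show. With Spark 1.x support dropped, each function (ExecuteJoinFromCoGroupFlatMapFunction above, FilterAndFlattenJoinedValues, and the others below) now implements org.apache.spark.api.java.function.FlatMapFunction directly, builds its result list, and returns list.iterator() or Collections.emptyIterator(). A minimal sketch of the Spark 2.x shape, using an illustrative function that is not part of this codebase:

import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.Arrays;
import java.util.Iterator;

// Spark 2.x style: call(...) returns an Iterator rather than an Iterable
public class WhitespaceTokenizer implements FlatMapFunction<String, String> {
    @Override
    public Iterator<String> call(String line) throws Exception {
        // Materialise the full result, then hand back an iterator over it
        return Arrays.asList(line.split("\\s+")).iterator();
    }
}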
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.join; - -import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Collections; -import java.util.List; - -/** - * Doing two things here: - * (a) filter out any unnecessary values, and - * (b) extract the List values from the JoinedValue - * - * @author Alex Black - */ -public class FilterAndFlattenJoinedValuesAdapter implements FlatMapFunctionAdapter> { - - private final Join.JoinType joinType; - - public FilterAndFlattenJoinedValuesAdapter(Join.JoinType joinType) { - this.joinType = joinType; - } - - @Override - public Iterable> call(JoinedValue joinedValue) throws Exception { - boolean keep; - switch (joinType) { - case Inner: - //Only keep joined values where we have both left and right - keep = joinedValue.isHaveLeft() && joinedValue.isHaveRight(); - break; - case LeftOuter: - //Keep all values where left is not missing/null - keep = joinedValue.isHaveLeft(); - break; - case RightOuter: - //Keep all values where right is not missing/null - keep = joinedValue.isHaveRight(); - break; - case FullOuter: - //Keep all values - keep = true; - break; - default: - throw new RuntimeException("Unknown/not implemented join type: " + joinType); - } - - if (keep) { - return Collections.singletonList(joinedValue.getValues()); - } else { - return Collections.emptyList(); - } - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java index 28c91c84b..639e43836 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java @@ -16,21 +16,69 @@ package org.datavec.spark.transform.sparkfunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.types.StructType; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.datavec.spark.transform.DataFrames; -import java.util.List; +import java.util.*; /** * Convert a record to a row * @author Adam Gibson */ -public class SequenceToRows extends BaseFlatMapFunctionAdaptee>, Row> { +public class SequenceToRows implements FlatMapFunction>, Row> { + + private Schema schema; + private StructType structType; public SequenceToRows(Schema schema) { - super(new SequenceToRowsAdapter(schema)); + this.schema = schema; + structType = DataFrames.fromSchemaSequence(schema); } + + @Override + public Iterator call(List> sequence) throws 
Exception { + if (sequence.size() == 0) + return Collections.emptyIterator(); + + String sequenceUUID = UUID.randomUUID().toString(); + + List out = new ArrayList<>(sequence.size()); + + int stepCount = 0; + for (List step : sequence) { + Object[] values = new Object[step.size() + 2]; + values[0] = sequenceUUID; + values[1] = stepCount++; + for (int i = 0; i < step.size(); i++) { + switch (schema.getColumnTypes().get(i)) { + case Double: + values[i + 2] = step.get(i).toDouble(); + break; + case Integer: + values[i + 2] = step.get(i).toInt(); + break; + case Long: + values[i + 2] = step.get(i).toLong(); + break; + case Float: + values[i + 2] = step.get(i).toFloat(); + break; + default: + throw new IllegalStateException( + "This api should not be used with strings , binary data or ndarrays. This is only for columnar data"); + } + } + + Row row = new GenericRowWithSchema(values, structType); + out.add(row); + } + + return out.iterator(); + } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java deleted file mode 100644 index 2ca2f32ae..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRowsAdapter.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
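For reference, SequenceToRows (above) flattens each sequence into one Row per time step. The first two columns of every Row are a per-sequence UUID (so steps can later be grouped back into their sequence) and the step index; the remaining columns are the schema's values, which must be numeric (Double, Integer, Long or Float). For example, a two-step sequence over a schema with double columns "a" and "b" becomes the rows [uuid, 0, a0, b0] and [uuid, 1, a1, b1], with the same randomly generated uuid in both rows. The matching sequence-UUID and step-index fields are assumed to be prepended to the struct type returned by DataFrames.fromSchemaSequence(schema).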
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.sparkfunction; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; -import org.apache.spark.sql.types.StructType; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.DataFrames; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -/** - * Convert a record to a row - * @author Adam Gibson - */ -public class SequenceToRowsAdapter implements FlatMapFunctionAdapter>, Row> { - - private Schema schema; - private StructType structType; - - public SequenceToRowsAdapter(Schema schema) { - this.schema = schema; - structType = DataFrames.fromSchemaSequence(schema); - } - - - @Override - public Iterable call(List> sequence) throws Exception { - if (sequence.size() == 0) - return Collections.emptyList(); - - String sequenceUUID = UUID.randomUUID().toString(); - - List out = new ArrayList<>(sequence.size()); - - int stepCount = 0; - for (List step : sequence) { - Object[] values = new Object[step.size() + 2]; - values[0] = sequenceUUID; - values[1] = stepCount++; - for (int i = 0; i < step.size(); i++) { - switch (schema.getColumnTypes().get(i)) { - case Double: - values[i + 2] = step.get(i).toDouble(); - break; - case Integer: - values[i + 2] = step.get(i).toInt(); - break; - case Long: - values[i + 2] = step.get(i).toLong(); - break; - case Float: - values[i + 2] = step.get(i).toFloat(); - break; - default: - throw new IllegalStateException( - "This api should not be used with strings , binary data or ndarrays. This is only for columnar data"); - } - } - - Row row = new GenericRowWithSchema(values, structType); - out.add(row); - } - - return out; - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java index 41981a736..1a8782dfb 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java @@ -16,19 +16,27 @@ package org.datavec.spark.transform.transform; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.sequence.SequenceSplit; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Iterator; import java.util.List; /** * Created by Alex on 17/03/2016. 
*/ -public class SequenceSplitFunction extends BaseFlatMapFunctionAdaptee>, List>> { +public class SequenceSplitFunction implements FlatMapFunction>, List>> { + + private final SequenceSplit split; public SequenceSplitFunction(SequenceSplit split) { - super(new SequenceSplitFunctionAdapter(split)); + this.split = split; + } + + @Override + public Iterator>> call(List> collections) throws Exception { + return split.split(collections).iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java deleted file mode 100644 index 5bde7ee62..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunctionAdapter.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.transform; - -import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.List; - -/** - * Created by Alex on 17/03/2016. 
- */ -public class SequenceSplitFunctionAdapter - implements FlatMapFunctionAdapter>, List>> { - - private final SequenceSplit split; - - public SequenceSplitFunctionAdapter(SequenceSplit split) { - this.split = split; - } - - @Override - public Iterable>> call(List> collections) throws Exception { - return split.split(collections); - } -} diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java index ffe3f80c6..81f07b1f4 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java +++ b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java @@ -16,19 +16,32 @@ package org.datavec.spark.transform.transform; +import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.TransformProcess; import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import java.util.Collections; +import java.util.Iterator; import java.util.List; /** * Spark function for executing a transform process */ -public class SparkTransformProcessFunction extends BaseFlatMapFunctionAdaptee, List> { +public class SparkTransformProcessFunction implements FlatMapFunction, List> { + + private final TransformProcess transformProcess; public SparkTransformProcessFunction(TransformProcess transformProcess) { - super(new SparkTransformProcessFunctionAdapter(transformProcess)); + this.transformProcess = transformProcess; + } + + @Override + public Iterator> call(List v1) throws Exception { + List newList = transformProcess.execute(v1); + if (newList == null) + return Collections.emptyIterator(); //Example was filtered out + else + return Collections.singletonList(newList).iterator(); } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java b/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java deleted file mode 100644 index 7b1766cc2..000000000 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunctionAdapter.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
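SparkTransformProcessFunction (above) relies on TransformProcess.execute(...) returning null when a filter step removes the example; the null result is mapped to Collections.emptyIterator(), so the record simply disappears from the output RDD, while every surviving record yields exactly one transformed output. A minimal usage sketch; the helper class and variable names here are illustrative only, not taken from this changeset:

import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.transform.SparkTransformProcessFunction;

import java.util.List;

public class TransformProcessExample {
    // Applies a DataVec TransformProcess to every record in the RDD;
    // records removed by filter steps are dropped from the result.
    public static JavaRDD<List<Writable>> apply(JavaRDD<List<Writable>> records, TransformProcess tp) {
        return records.flatMap(new SparkTransformProcessFunction(tp));
    }
}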
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform.transform; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Collections; -import java.util.List; - -/** - * Spark function for executing a transform process - */ -public class SparkTransformProcessFunctionAdapter implements FlatMapFunctionAdapter, List> { - - private final TransformProcess transformProcess; - - public SparkTransformProcessFunctionAdapter(TransformProcess transformProcess) { - this.transformProcess = transformProcess; - } - - @Override - public Iterable> call(List v1) throws Exception { - List newList = transformProcess.execute(v1); - if (newList == null) - return Collections.emptyList(); //Example was filtered out - else - return Collections.singletonList(newList); - } -} diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java deleted file mode 100644 index af600a14d..000000000 --- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.api.java.function.FlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -/** - * FlatMapFunction adapter to - * hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to FlatMapFunction - * - */ -public class BaseFlatMapFunctionAdaptee implements FlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterable call(K k) throws Exception { - return adapter.call(k); - } -} diff --git a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java deleted file mode 100644 index 0ad7c55bd..000000000 --- a/datavec/datavec-spark/src/main/spark-1/org/datavec/spark/transform/DataRowsFacade.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.sql.DataFrame; - -/** - * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DataFrame / Dataset - * - */ -public class DataRowsFacade { - - private final DataFrame df; - - private DataRowsFacade(DataFrame df) { - this.df = df; - } - - public static DataRowsFacade dataRows(DataFrame df) { - return new DataRowsFacade(df); - } - - public DataFrame get() { - return df; - } -} diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java deleted file mode 100644 index f30e5a222..000000000 --- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/BaseFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.api.java.function.FlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Iterator; - -/** - * FlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to FlatMapFunction - * - */ -public class BaseFlatMapFunctionAdaptee implements FlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterator call(K k) throws Exception { - return adapter.call(k).iterator(); - } -} diff --git a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java b/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java deleted file mode 100644 index 9958a622e..000000000 --- a/datavec/datavec-spark/src/main/spark-2/org/datavec/spark/transform/DataRowsFacade.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
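DataRowsFacade (the spark-1 copy deleted above, the spark-2 copy deleted below) played the same compatibility role for the SQL API: Spark 1.x exposed DataFrame as a concrete class, while Spark 2.x replaced it with Dataset<Row>. With only Spark 2.x supported, callers now work with Dataset<Row> directly, which is why the test updates below replace dataFrame.get().show() with dataFrame.show() and type the variable as Dataset<Row>.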
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.spark.transform; - -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; - -/** - * Dataframe facade to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DataFrame / Dataset - * - */ -public class DataRowsFacade { - - private final Dataset df; - - private DataRowsFacade(Dataset df) { - this.df = df; - } - - public static DataRowsFacade dataRows(Dataset df) { - return new DataRowsFacade(df); - } - - public Dataset get() { - return df; - } -} diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java index 49a3946e2..8f0247568 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java @@ -16,7 +16,7 @@ package org.datavec.spark.storage; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.datavec.api.writable.*; diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java index 5b6ff6342..a19725a2a 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/DataFramesTests.java @@ -19,6 +19,8 @@ package org.datavec.spark.transform; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Column; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; @@ -46,9 +48,9 @@ public class DataFramesTests extends BaseSparkTest { for (int i = 0; i < numColumns; i++) builder.addColumnDouble(String.valueOf(i)); Schema schema = builder.build(); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records)); - dataFrame.get().show(); - dataFrame.get().describe(DataFrames.toArray(schema.getColumnNames())).show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, sc.parallelize(records)); + dataFrame.show(); + dataFrame.describe(DataFrames.toArray(schema.getColumnNames())).show(); // System.out.println(Normalization.minMaxColumns(dataFrame,schema.getColumnNames())); // System.out.println(Normalization.stdDevMeanColumns(dataFrame,schema.getColumnNames())); @@ -77,12 +79,12 @@ public class DataFramesTests extends BaseSparkTest { assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); - DataRowsFacade dataFrame = 
DataFrames.toDataFrame(schema, rdd); - dataFrame.get().show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); + dataFrame.show(); Column mean = DataFrames.mean(dataFrame, "0"); Column std = DataFrames.std(dataFrame, "0"); - dataFrame.get().withColumn("0", dataFrame.get().col("0").minus(mean)).show(); - dataFrame.get().withColumn("0", dataFrame.get().col("0").divide(std)).show(); + dataFrame.withColumn("0", dataFrame.col("0").minus(mean)).show(); + dataFrame.withColumn("0", dataFrame.col("0").divide(std)).show(); /* DataFrame desc = dataFrame.describe(dataFrame.columns()); dataFrame.show(); diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 3b1b1e1a6..5352ec10d 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -17,6 +17,7 @@ package org.datavec.spark.transform; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; @@ -24,11 +25,13 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.datavec.spark.BaseSparkTest; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; +import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; @@ -50,36 +53,35 @@ public class NormalizationTests extends BaseSparkTest { for (int i = 0; i < numColumns; i++) builder.addColumnDouble(String.valueOf(i)); + Nd4j.getRandom().setSeed(12345); + + INDArray arr = Nd4j.rand(DataType.FLOAT, 5, numColumns); for (int i = 0; i < 5; i++) { List record = new ArrayList<>(numColumns); data.add(record); for (int j = 0; j < numColumns; j++) { - record.add(new DoubleWritable(1.0)); + record.add(new DoubleWritable(arr.getDouble(i, j))); } - } - INDArray arr = RecordConverter.toMatrix(data); Schema schema = builder.build(); JavaRDD> rdd = sc.parallelize(data); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); //assert equivalent to the ndarray pre processing - NormalizerStandardize standardScaler = new NormalizerStandardize(); - standardScaler.fit(new DataSet(arr.dup(), arr.dup())); - INDArray standardScalered = arr.dup(); - standardScaler.transform(new DataSet(standardScalered, standardScalered)); DataNormalization zeroToOne = new NormalizerMinMaxScaler(); zeroToOne.fit(new DataSet(arr.dup(), arr.dup())); INDArray zeroToOnes = arr.dup(); zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); - List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns()); + List rows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.columns()); INDArray assertion = DataFrames.toMatrix(rows); - //compare standard deviation - assertTrue(standardScaler.getStd().equalsWithEps(assertion.getRow(0), 1e-1)); + INDArray expStd = arr.std(true, true, 0); + INDArray std = assertion.getRow(0, true); + 
assertTrue(expStd.equalsWithEps(std, 1e-3)); //compare mean - assertTrue(standardScaler.getMean().equalsWithEps(assertion.getRow(1), 1e-1)); + INDArray expMean = arr.mean(true, 0); + assertTrue(expMean.equalsWithEps(assertion.getRow(1, true), 1e-3)); } @@ -109,10 +111,10 @@ public class NormalizationTests extends BaseSparkTest { assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema))); assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect()); - DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd); - dataFrame.get().show(); - Normalization.zeromeanUnitVariance(dataFrame).get().show(); - Normalization.normalize(dataFrame).get().show(); + Dataset dataFrame = DataFrames.toDataFrame(schema, rdd); + dataFrame.show(); + Normalization.zeromeanUnitVariance(dataFrame).show(); + Normalization.normalize(dataFrame).show(); //assert equivalent to the ndarray pre processing NormalizerStandardize standardScaler = new NormalizerStandardize(); diff --git a/datavec/pom.xml b/datavec/pom.xml index 62c722ace..7e9bf9bf0 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -330,53 +330,6 @@ - - - - test-nd4j-native - - false - - - - org.deeplearning4j - dl4j-test-resources - ${dl4j-test-resources.version} - test - - - org.nd4j - nd4j-native - ${nd4j.version} - test - - - - - - test-nd4j-cuda-10.1 - - false - - - - org.deeplearning4j - dl4j-test-resources - ${dl4j-test-resources.version} - test - - - org.nd4j - nd4j-cuda-10.1 - ${nd4j.version} - test - - - - - - @@ -391,4 +344,42 @@ + + + + test-nd4j-native + + + org.nd4j + nd4j-native + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + + diff --git a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java index 7b0d9d3c1..04773ee74 100644 --- a/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java +++ b/deeplearning4j/deeplearning4j-common/src/main/java/org/deeplearning4j/config/DL4JSystemProperties.java @@ -52,18 +52,6 @@ public class DL4JSystemProperties { */ public static final String DL4J_RESOURCES_BASE_URL_PROPERTY = "org.deeplearning4j.resources.baseurl"; - /** - * Applicability: deeplearning4j-nn
- * Description: Used for loading legacy format JSON containing custom layers. This system property is provided as an - * alternative to {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])}. Classes are specified in - * comma-separated format.
- * This is required ONLY when ALL of the following conditions are met:
- * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND
- * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J), AND
- * 3. You haven't already called {@code NeuralNetConfiguration#registerLegacyCustomClassesForJSON(Class[])} - */ - public static final String CUSTOM_REGISTRATION_PROPERTY = "org.deeplearning4j.config.custom.legacyclasses"; - /** * Applicability: deeplearning4j-nn
* Description: DL4J writes some crash dumps to disk when an OOM exception occurs - this functionality is enabled diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index ba9020a8d..81142fd68 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -29,23 +29,11 @@ nd4j-api ${nd4j.version}
- - org.nd4j - nd4j-native - ${nd4j.version} - test - org.nd4j nd4j-common ${nd4j.version} - - org.nd4j - nd4j-cuda-10.1 - ${nd4j.version} - test - @@ -82,7 +70,6 @@ test - org.deeplearning4j deeplearning4j-nn @@ -109,18 +96,6 @@ test - - com.google.guava - guava - ${guava.version} - - - com.google.code.findbugs - jsr305 - - - - org.nd4j nd4j-api @@ -132,8 +107,6 @@ ${commonslang.version} - - org.nd4j @@ -141,7 +114,6 @@ ${nd4j.version} - org.projectlombok lombok @@ -180,9 +152,26 @@ test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 56317b946..3bd1bd37f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.datasets.datavec; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 6ff639cb4..1e82a4783 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -17,7 +17,7 @@ package org.deeplearning4j.datasets.datavec; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java index 30ef68183..096c7ac69 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java @@ -22,12 +22,16 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import static org.junit.Assert.fail; +import java.util.Map; + +import static org.junit.Assert.*; /** * A set of tests to ensure that useful exceptions are thrown on invalid input @@ -267,23 +271,44 @@ public class TestInvalidInput extends BaseDL4JTest { //Idea: Using rnnTimeStep 
with a different number of examples between calls //(i.e., not calling reset between time steps) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build()) - .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); + for(String layerType : new String[]{"simple", "lstm", "graves"}) { - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + Layer l; + switch (layerType){ + case "simple": + l = new SimpleRnn.Builder().nIn(5).nOut(5).build(); + break; + case "lstm": + l = new LSTM.Builder().nIn(5).nOut(5).build(); + break; + case "graves": + l = new GravesLSTM.Builder().nIn(5).nOut(5).build(); + break; + default: + throw new RuntimeException(); + } - net.rnnTimeStep(Nd4j.create(3, 5, 10)); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + .layer(l) + .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); - try { - net.rnnTimeStep(Nd4j.create(5, 5, 10)); - fail("Expected DL4JException"); - } catch (DL4JException e) { - System.out.println("testInvalidRnnTimeStep(): " + e.getMessage()); - } catch (Exception e) { - e.printStackTrace(); - fail("Expected DL4JException"); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + net.rnnTimeStep(Nd4j.create(3, 5, 10)); + + Map m = net.rnnGetPreviousState(0); + assertNotNull(m); + assertFalse(m.isEmpty()); + + try { + net.rnnTimeStep(Nd4j.create(5, 5, 10)); + fail("Expected Exception - " + layerType); + } catch (Exception e) { +// e.printStackTrace(); + String msg = e.getMessage(); + assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); + } } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index c1a873dc8..decb81bb0 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -343,6 +344,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { assertTrue(msg, gradOK); + //Also check compgraph: + ComputationGraph cg = net.toComputationGraph(); + gradOK = GradientCheckUtil.checkGradients(cg, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[]{input}, new INDArray[]{labels}); + assertTrue(msg + " - compgraph", gradOK); + TestUtils.testModelSerialization(net); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 935d83218..52d3b0774 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -16,8 +16,8 @@ package org.deeplearning4j.nn.dtypes; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; @@ -103,7 +103,7 @@ public class DTypeTests extends BaseDL4JTest { ImmutableSet info; try { //Dependency note: this ClassPath class was added in Guava 14 - info = com.google.common.reflect.ClassPath.from(DTypeTests.class.getClassLoader()) + info = org.nd4j.shade.guava.reflect.ClassPath.from(DTypeTests.class.getClassLoader()) .getTopLevelClassesRecursive("org.deeplearning4j"); } catch (IOException e) { //Should never happen diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index 297067862..751b6f6bf 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -186,7 +186,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, false, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, - null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces()).fwdPassOutput; + null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutput; final INDArray[] fwdPassTrue = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), lstm.input(), @@ -194,7 +194,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), false, null, null, true, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, null, - CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces()).fwdPassOutputAsArrays; + CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutputAsArrays; //I have no idea what the heck this does --Ben for (int i = 0; i < timeSeriesLength; i++) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index fb5836a99..11e45c51d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -229,7 +229,9 @@ public class TestRnnLayers extends BaseDL4JTest { net.fit(in,l); } catch (Throwable t){ String msg = t.getMessage(); - assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); + if(msg == null) + t.printStackTrace(); + assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); } } diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java index 77fc63aa0..f65783bce 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.plot; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.IOUtils; @@ -498,9 +498,10 @@ public class BarnesHutTsneTest extends BaseDL4JTest { 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860, 0.6248951423054205, 0.7431868493349041}; INDArray ndinput = Nd4j.createFromArray(input).reshape(11, 5); - double[] rows = {0, 10.0000, 20.0000, 30.0000, 40.0000, 50.0000, 60.0000, 69.0000, 78.0000, 88.0000, 98.0000, 108.0000}; + int[] rows = {0, 10, 20, 30, 40, 50, 60, 69, 78, 88, 98, 108}; INDArray indRows = Nd4j.createFromArray(rows); - double[] cols = {4.0000, 3.0000, 10.0000, 8.0000, 6.0000, 7.0000, 1.0000, 5.0000, 9.0000, 2.0000, 0, 4.0000, 9.0000, 8.0000, 10.0000, 2.0000, 6.0000, 7.0000, 3.0000, 5.0000, 1.0000, 6.0000, 8.0000, 3.0000, 9.0000, 10.0000, 4.0000, 0, 5.0000, 7.0000, 0, 1.0000, 2.0000, 10.0000, 4.0000, 6.0000, 8.0000, 9.0000, 5.0000, 7.0000, 0, 1.0000, 2.0000, 3.0000, 10.0000, 8.0000, 9.0000, 6.0000, 7.0000, 5.0000, 0, 2.0000, 3.0000, 7.0000, 9.0000, 10.0000, 4.0000, 8.0000, 1.0000, 6.0000, 0, 1.0000, 2.0000, 3.0000, 4.0000, 8.0000, 10.0000, 9.0000, 5.0000, 0, 1.0000, 3.0000, 4.0000, 5.0000, 9.0000, 10.0000, 8.0000, 2.0000, 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 10.0000, 9.0000, 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 10.0000, 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000}; + int[] cols = {4, 3, 10, 8, 6, 7, 1, 5, 9, 2, 0, 4, 9, 8, 10, 2, 6, 7, 3, 5, 1, 6, 8, 3, 9, 10, 4, 0, 5, 7, 0, 1, 2, 10, 4, 6, 8, 9, + 5, 7, 0, 1, 2, 3, 10, 8, 9, 6, 7, 5, 0, 2, 3, 7, 9, 10, 4, 8, 1, 6, 0, 1, 2, 3, 4, 8, 10, 9, 5, 0, 1, 3, 4, 5, 9, 10, 8, 2, 0, 1, 2, 3, 4, 5, 6, 7, 10, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; INDArray indCols = Nd4j.createFromArray(cols); double[] vals = {0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911}; INDArray indVals = Nd4j.createFromArray(vals); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java 
b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index f112d4386..a66914cd7 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -22,23 +22,24 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; -import org.junit.Before; import org.junit.Ignore; import org.junit.Test; -import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.regularization.WeightDecay; @@ -60,6 +61,9 @@ public class RegressionTest100a extends BaseDL4JTest { @Test public void testCustomLayer() throws Exception { + //We dropped support for 1.0.0-alpha and earlier custom layers due to the maintenance overhead for a rarely used feature + //An upgrade path exists as a workaround - load in beta to beta4 and re-save + //All built-in layers can be loaded going back to 0.5.0 File f = Resources.asFile("regression_testing/100a/CustomLayerExample_100a.bin"); @@ -68,67 +72,8 @@ public class RegressionTest100a extends BaseDL4JTest { fail("Expected exception"); } catch (Exception e){ String msg = e.getMessage(); - assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON")); + assertTrue(msg, msg.contains("custom") && msg.contains("1.0.0-beta") && msg.contains("saved again")); } - - NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class); - - MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - - DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer(); - assertEquals(new ActivationTanH(), l0.getActivationFn()); - assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0)); - assertEquals(new RmsProp(0.95), l0.getIUpdater()); - - CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer(); - assertEquals(new ActivationTanH(), l1.getActivationFn()); - assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); - assertEquals(new RmsProp(0.95), l1.getIUpdater()); - - - INDArray outExp; - File f2 = 
Resources.asFile("regression_testing/100a/CustomLayerExample_Output_100a.bin"); - try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){ - outExp = Nd4j.read(dis); - } - - INDArray in; - File f3 = Resources.asFile("regression_testing/100a/CustomLayerExample_Input_100a.bin"); - try(DataInputStream dis = new DataInputStream(new FileInputStream(f3))){ - in = Nd4j.read(dis); - } - - INDArray outAct = net.output(in); - - assertEquals(outExp, outAct); - - - //Check graph - f = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_100a.bin"); - - //Deregister custom class: - new LegacyLayerDeserializer().getLegacyNamesMap().remove("CustomLayer"); - - try { - ComputationGraph.load(f, true); - fail("Expected exception"); - } catch (Exception e){ - String msg = e.getMessage(); - assertTrue(msg, msg.contains("NeuralNetConfiguration.registerLegacyCustomClassesForJSON")); - } - - NeuralNetConfiguration.registerLegacyCustomClassesForJSON(CustomLayer.class); - - ComputationGraph graph = ComputationGraph.load(f, true); - - f2 = Resources.asFile("regression_testing/100a/CustomLayerExample_Graph_Output_100a.bin"); - try(DataInputStream dis = new DataInputStream(new FileInputStream(f2))){ - outExp = Nd4j.read(dis); - } - - outAct = graph.outputSingle(in); - - assertEquals(outExp, outAct); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java new file mode 100644 index 000000000..d1112899f --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -0,0 +1,402 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.deeplearning4j.regressiontest; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.graph.LayerVertex; +import org.deeplearning4j.nn.conf.layers.CnnLossLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; +import org.deeplearning4j.nn.conf.layers.Upsampling2D; +import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.graph.vertex.impl.MergeVertex; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInitXavier; +import org.deeplearning4j.regressiontest.customlayer100a.CustomLayer; +import org.junit.Test; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationReLU; +import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.activations.impl.ActivationSoftmax; +import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.RmsProp; +import org.nd4j.linalg.learning.regularization.L2Regularization; +import org.nd4j.linalg.lossfunctions.impl.LossMAE; +import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; +import org.nd4j.resources.Resources; + +public class RegressionTest100b4 extends BaseDL4JTest { + + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Test + public void testCustomLayer() throws Exception { + + for (DataType dtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { + + String dtypeName = dtype.toString().toLowerCase(); + + File f = Resources.asFile("regression_testing/100b4/CustomLayerExample_100b4_" + dtypeName + ".bin"); + MultiLayerNetwork.load(f, true); + + MultiLayerNetwork net = MultiLayerNetwork.load(f, true); +// net = net.clone(); + + DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer(); + assertEquals(new ActivationTanH(), l0.getActivationFn()); + assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0)); + assertEquals(new RmsProp(0.95), l0.getIUpdater()); + + CustomLayer l1 = 
(CustomLayer) net.getLayer(1).conf().getLayer(); + assertEquals(new ActivationTanH(), l1.getActivationFn()); + assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); + assertEquals(new RmsProp(0.95), l1.getIUpdater()); + + INDArray outExp; + File f2 = Resources + .asFile("regression_testing/100b4/CustomLayerExample_Output_100b4_" + dtypeName + ".bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/CustomLayerExample_Input_100b4_" + dtypeName + ".bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + assertEquals(dtype, in.dataType()); + assertEquals(dtype, outExp.dataType()); + assertEquals(dtype, net.params().dataType()); + assertEquals(dtype, net.getFlattenedGradients().dataType()); + assertEquals(dtype, net.getUpdater().getStateViewArray().dataType()); + + //System.out.println(Arrays.toString(net.params().data().asFloat())); + + INDArray outAct = net.output(in); + assertEquals(dtype, outAct.dataType()); + + assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); + assertEquals(dtype, net.params().dataType()); + assertEquals("Test for dtype: " + dtypeName, outExp, outAct); + } + } + + + @Test + public void testLSTM() throws Exception { + + File f = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_100b4.bin"); + MultiLayerNetwork net = MultiLayerNetwork.load(f, true); + + LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer(); + assertEquals(new ActivationTanH(), l0.getActivationFn()); + assertEquals(200, l0.getNOut()); + assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); + assertEquals(new Adam(0.005), l0.getIUpdater()); + + LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer(); + assertEquals(new ActivationTanH(), l1.getActivationFn()); + assertEquals(200, l1.getNOut()); + assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); + assertEquals(new Adam(0.005), l1.getIUpdater()); + + RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer(); + assertEquals(new ActivationSoftmax(), l2.getActivationFn()); + assertEquals(77, l2.getNOut()); + assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); + assertEquals(new Adam(0.005), l2.getIUpdater()); + + assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType()); + assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength()); + assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength()); + + INDArray outExp; + File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Input_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + INDArray outAct = net.output(in); + + assertEquals(outExp, outAct); + } + + @Test + public void testVae() throws Exception { + + File f = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_100b4.bin"); + MultiLayerNetwork net = MultiLayerNetwork.load(f, true); + + 
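// Illustrative sketch only (assumed workflow, not part of this patch): the *_Input_100b4.bin /
// *_Output_100b4.bin fixtures read throughout this class are plain ND4J binary dumps, consumed
// via Nd4j.read(DataInputStream). Assuming Nd4j.write(INDArray, DataOutputStream) as its inverse,
// a fixture pair for this VAE test could have been produced under 1.0.0-beta4 roughly as follows
// (the file names and the 'input' array here are hypothetical):
//
//   INDArray input = Nd4j.rand(DataType.FLOAT, 3, 784);   // a small batch of flattened MNIST-sized examples
//   try (DataOutputStream dos = new DataOutputStream(
//           new FileOutputStream("VaeMNISTAnomaly_Input_100b4.bin"))) {
//       Nd4j.write(input, dos);                            // saved network input
//   }
//   try (DataOutputStream dos = new DataOutputStream(
//           new FileOutputStream("VaeMNISTAnomaly_Output_100b4.bin"))) {
//       Nd4j.write(net.output(input), dos);                // expected output asserted against below
//   }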
VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer(); + assertEquals(new ActivationLReLU(), l0.getActivationFn()); + assertEquals(32, l0.getNOut()); + assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); + assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); + assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); + assertEquals(new Adam(1e-3), l0.getIUpdater()); + + INDArray outExp; + File f2 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Output_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_Input_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + INDArray outAct = net.output(in); + + assertEquals(outExp, outAct); + } + + + @Test + public void testYoloHouseNumber() throws Exception { + + File f = Resources.asFile("regression_testing/100b4/HouseNumberDetection_100b4.bin"); + ComputationGraph net = ComputationGraph.load(f, true); + + int nBoxes = 5; + int nClasses = 10; + + ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices() + .get("convolution2d_9")).getLayerConf().getLayer(); + assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); + assertEquals(new ActivationIdentity(), cl.getActivationFn()); + assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); + assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); + assertArrayEquals(new int[]{1, 1}, cl.getKernelSize()); + + INDArray outExp; + File f2 = Resources.asFile("regression_testing/100b4/HouseNumberDetection_Output_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/HouseNumberDetection_Input_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + INDArray outAct = net.outputSingle(in); + + boolean eq = outExp.equalsWithEps(outAct.castTo(outExp.dataType()), 1e-3); + assertTrue(eq); + } + + @Test + public void testSyntheticCNN() throws Exception { + + File f = Resources.asFile("regression_testing/100b4/SyntheticCNN_100b4.bin"); + MultiLayerNetwork net = MultiLayerNetwork.load(f, true); + + ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer(); + assertEquals(new ActivationReLU(), l0.getActivationFn()); + assertEquals(4, l0.getNOut()); + assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); + assertEquals(new Adam(0.005), l0.getIUpdater()); + assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); + assertArrayEquals(new int[]{2, 1}, l0.getStride()); + assertArrayEquals(new int[]{1, 1}, l0.getDilation()); + assertArrayEquals(new int[]{0, 0}, l0.getPadding()); + + SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer(); + assertEquals(new ActivationReLU(), l1.getActivationFn()); + assertEquals(8, l1.getNOut()); + assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); + assertEquals(new Adam(0.005), l1.getIUpdater()); + assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); + assertArrayEquals(new int[]{1, 1}, 
l1.getStride()); + assertArrayEquals(new int[]{1, 1}, l1.getDilation()); + assertArrayEquals(new int[]{0, 0}, l1.getPadding()); + assertEquals(ConvolutionMode.Same, l1.getConvolutionMode()); + assertEquals(1, l1.getDepthMultiplier()); + + SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer(); + assertArrayEquals(new int[]{3, 3}, l2.getKernelSize()); + assertArrayEquals(new int[]{2, 2}, l2.getStride()); + assertArrayEquals(new int[]{1, 1}, l2.getDilation()); + assertArrayEquals(new int[]{0, 0}, l2.getPadding()); + assertEquals(PoolingType.MAX, l2.getPoolingType()); + + ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer(); + assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding()); + + Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer(); + assertArrayEquals(new int[]{3, 3}, l4.getSize()); + + DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer(); + assertEquals(new ActivationReLU(), l5.getActivationFn()); + assertEquals(16, l5.getNOut()); + assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); + assertEquals(new Adam(0.005), l5.getIUpdater()); + assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); + assertArrayEquals(new int[]{1, 1}, l5.getStride()); + assertArrayEquals(new int[]{1, 1}, l5.getDilation()); + assertArrayEquals(new int[]{0, 0}, l5.getPadding()); + assertEquals(2, l5.getDepthMultiplier()); + + SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer(); + assertArrayEquals(new int[]{2, 2}, l6.getKernelSize()); + assertArrayEquals(new int[]{2, 2}, l6.getStride()); + assertArrayEquals(new int[]{1, 1}, l6.getDilation()); + assertArrayEquals(new int[]{0, 0}, l6.getPadding()); + assertEquals(PoolingType.MAX, l6.getPoolingType()); + + Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer(); + assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping()); + + ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer(); + assertEquals(4, l8.getNOut()); + assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); + assertEquals(new Adam(0.005), l8.getIUpdater()); + assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); + assertArrayEquals(new int[]{1, 1}, l8.getStride()); + assertArrayEquals(new int[]{1, 1}, l8.getDilation()); + assertArrayEquals(new int[]{0, 0}, l8.getPadding()); + + CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer(); + assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); + assertEquals(new Adam(0.005), l9.getIUpdater()); + assertEquals(new LossMAE(), l9.getLossFn()); + + INDArray outExp; + File f2 = Resources.asFile("regression_testing/100b4/SyntheticCNN_Output_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/SyntheticCNN_Input_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + INDArray outAct = net.output(in); + + assertEquals(outExp, outAct); + } + + @Test + public void testSyntheticBidirectionalRNNGraph() throws Exception { + + File f = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_100b4.bin"); + ComputationGraph net = ComputationGraph.load(f, true); + + 
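// Note (descriptive comment, not part of this patch): the boolean passed to ComputationGraph.load /
// MultiLayerNetwork.load throughout these tests is loadUpdater - when true, the optimizer state
// (e.g. Adam/RMSProp moments) is restored along with the parameters, which is what allows
// testCustomLayer above to assert on net.getUpdater().getStateViewArray().dataType().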
Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer(); + + LSTM l1 = (LSTM) l0.getFwd(); + assertEquals(16, l1.getNOut()); + assertEquals(new ActivationReLU(), l1.getActivationFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); + + LSTM l2 = (LSTM) l0.getBwd(); + assertEquals(16, l2.getNOut()); + assertEquals(new ActivationReLU(), l2.getActivationFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); + + Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer(); + + SimpleRnn l4 = (SimpleRnn) l3.getFwd(); + assertEquals(16, l4.getNOut()); + assertEquals(new ActivationReLU(), l4.getActivationFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l4)); + + SimpleRnn l5 = (SimpleRnn) l3.getBwd(); + assertEquals(16, l5.getNOut()); + assertEquals(new ActivationReLU(), l5.getActivationFn()); + assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); + + MergeVertex mv = (MergeVertex) net.getVertex("concat"); + + GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer(); + assertEquals(PoolingType.MAX, gpl.getPoolingType()); + assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions()); + assertTrue(gpl.isCollapseDimensions()); + + OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer(); + assertEquals(3, outl.getNOut()); + assertEquals(new LossMCXENT(), outl.getLossFn()); + + INDArray outExp; + File f2 = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_Output_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f2))) { + outExp = Nd4j.read(dis); + } + + INDArray in; + File f3 = Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_Input_100b4.bin"); + try (DataInputStream dis = new DataInputStream(new FileInputStream(f3))) { + in = Nd4j.read(dis); + } + + INDArray outAct = net.output(in)[0]; + + assertEquals(outExp, outAct); + } +} diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java index bd8726fed..3be8401d7 100755 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java @@ -22,10 +22,12 @@ import org.deeplearning4j.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.ResourceType; import org.deeplearning4j.datasets.mnist.MnistManager; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.util.MathUtils; import java.io.File; @@ -167,14 +169,19 @@ public class MnistDataFetcher extends BaseDataFetcher { this(true); } + private float[][] featureData = null; + @Override public void fetch(int numExamples) { if (!hasMore()) { throw new IllegalStateException("Unable to get more; there are no more images"); } - float[][] featureData = new float[numExamples][0]; - float[][] labelData = new float[numExamples][0]; + INDArray 
labels = Nd4j.zeros(DataType.FLOAT, numExamples, numOutcomes); + + if(featureData == null || featureData.length < numExamples){ + featureData = new float[numExamples][28*28]; + } int actualExamples = 0; byte[] working = null; @@ -202,33 +209,33 @@ public class MnistDataFetcher extends BaseDataFetcher { label--; } - float[] featureVec = new float[img.length]; - featureData[actualExamples] = featureVec; - labelData[actualExamples] = new float[numOutcomes]; - labelData[actualExamples][label] = 1.0f; + labels.put(actualExamples, label, 1.0f); - for (int j = 0; j < img.length; j++) { - float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned - if (binarize) { - if (v > 30.0f) - featureVec[j] = 1.0f; - else - featureVec[j] = 0.0f; - } else { - featureVec[j] = v / 255.0f; - } + for(int j = 0 ; j < img.length ; j++) { + featureData[actualExamples][j] = ((int) img[j]) & 0xFF; } actualExamples++; } - if (actualExamples < numExamples) { - featureData = Arrays.copyOfRange(featureData, 0, actualExamples); - labelData = Arrays.copyOfRange(labelData, 0, actualExamples); + INDArray features; + + if(featureData.length == actualExamples){ + features = Nd4j.create(featureData); + } else { + features = Nd4j.create(Arrays.copyOfRange(featureData, 0, actualExamples)); + } + + if (actualExamples < numExamples) { + labels = labels.get(NDArrayIndex.interval(0, actualExamples), NDArrayIndex.all()); + } + + if(binarize){ + features = features.gt(30.0).castTo(DataType.FLOAT); + } else { + features.divi(255.0); } - INDArray features = Nd4j.create(featureData); - INDArray labels = Nd4j.create(labelData); curr = new DataSet(features, labels); } diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java index e7fb7246a..592aefcd6 100755 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java @@ -16,8 +16,8 @@ package org.deeplearning4j.datasets.iterator; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.annotations.VisibleForTesting; +import org.nd4j.shade.guava.collect.Lists; import lombok.Getter; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java index 955732c07..02f2c4eb0 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java @@ -16,7 +16,7 @@ package org.deeplearning4j.datasets.iterator.parallel; -import com.google.common.collect.Lists; +import 
org.nd4j.shade.guava.collect.Lists; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml index 971888fb1..c6227f27f 100644 --- a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml +++ b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml @@ -100,22 +100,31 @@ logback-classic test - - - org.nd4j - nd4j-native - ${nd4j.version} - test - test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + - diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index dd83c1bd4..7eca0fac0 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -17,7 +17,7 @@ package org.deeplearning4j.plot; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java index 41b50795e..9efb88e24 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java @@ -16,7 +16,7 @@ package org.deeplearning4j.plot; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dimensionalityreduction.PCA; diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml index 44e81c9a6..c550547d4 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml +++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml @@ -286,20 +286,31 @@ ${lucene-solr.version} test - - org.nd4j - nd4j-native - ${nd4j.version} - test - test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 1747848f0..dec29266f 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -37,6 +37,12 @@ ${nd4j.version} + + com.google.code.gson + gson + ${gson.version} + + org.deeplearning4j deeplearning4j-nn @@ -87,13 +93,6 @@ test - - org.nd4j - nd4j-native - ${nd4j.version} - test - - org.deeplearning4j deeplearning4j-datavec-iterators @@ -105,9 +104,26 @@ test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 0f1df0714..7477c7794 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -77,26 +77,11 @@ ${project.version} - - com.google.guava - guava - ${guava.version} - com.google.protobuf protobuf-java ${google.protobuf.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - joda-time joda-time @@ -213,26 +198,31 @@ play-netty-server_2.11 ${playframework.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - - - - org.nd4j - nd4j-native - ${nd4j.version} - test - test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java deleted file mode 100644 index df178fd70..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/FunctionUtil.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nearestneighbor.server; - -import play.libs.F; -import play.mvc.Result; - -import java.util.function.Function; -import java.util.function.Supplier; - -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; - } - - public static F.Function function(Function function) { - return function::apply; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index 58682d5e1..a79b57b19 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -33,8 +33,10 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.serde.base64.Nd4jBase64; import org.nd4j.serde.binary.BinarySerde; +import play.BuiltInComponents; import play.Mode; import play.libs.Json; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -149,19 +151,36 @@ public class NearestNeighborsServer { VPTree tree = new VPTree(points, similarityFunction, invert); - RoutingDsl routingDsl = new RoutingDsl(); + //Set play secret key, if required + //http://www.playframework.com/documentation/latest/ApplicationSecret + String crypto = System.getProperty("play.crypto.secret"); + if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { + byte[] newCrypto = new byte[1024]; + + new Random().nextBytes(newCrypto); + + String base64 = Base64.getEncoder().encodeToString(newCrypto); + System.setProperty("play.crypto.secret", base64); + } + + + server = Server.forRouter(Mode.PROD, port, b -> createRouter(tree, labels, points, b)); + } + + protected Router createRouter(VPTree tree, List labels, INDArray points, BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); //return the host information for a given id - routingDsl.POST("/knn").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/knn").routingTo(request -> { try { - NearestNeighborRequest record = Json.fromJson(request().body().asJson(), NearestNeighborRequest.class); + NearestNeighborRequest record = Json.fromJson(request.body().asJson(), NearestNeighborRequest.class); NearestNeighbor nearestNeighbor = - NearestNeighbor.builder().points(points).record(record).tree(tree).build(); + NearestNeighbor.builder().points(points).record(record).tree(tree).build(); if (record == null) return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); NearestNeighborsResults results = - NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); + NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); return ok(Json.toJson(results)); @@ -171,11 +190,11 @@ public class NearestNeighborsServer 
{ e.printStackTrace(); return internalServerError(e.getMessage()); } - }))); + }); - routingDsl.POST("/knnnew").routeTo(FunctionUtil.function0((() -> { + routingDsl.POST("/knnnew").routingTo(request -> { try { - Base64NDArrayBody record = Json.fromJson(request().body().asJson(), Base64NDArrayBody.class); + Base64NDArrayBody record = Json.fromJson(request.body().asJson(), Base64NDArrayBody.class); if (record == null) return badRequest(Json.toJson(Collections.singletonMap("status", "invalid json passed."))); @@ -216,23 +235,9 @@ public class NearestNeighborsServer { e.printStackTrace(); return internalServerError(e.getMessage()); } - }))); - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - server = Server.forRouter(routingDsl.build(), Mode.PROD, port); - + }); + return routingDsl.build(); } /** diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml index f1b9e0b00..f95f9268d 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml @@ -33,12 +33,6 @@ - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.nd4j nd4j-api @@ -65,15 +59,39 @@ ${project.version} test + + + joda-time + joda-time + 2.10.3 + test + test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java index 1c57bc38a..cae103f10 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java @@ -16,8 +16,8 @@ package org.deeplearning4j.clustering.info; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Table; +import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Table; import org.deeplearning4j.clustering.cluster.Cluster; import org.deeplearning4j.clustering.cluster.ClusterSet; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java index c2154e6ba..f1cc2e304 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java +++ 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.quadtree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java index 11746f4c2..5c31ee78a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.randomprojection; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java index 83af0365a..659f334df 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.sptree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.val; import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.nn.conf.WorkspaceMode; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 4b7b8e567..1de7a379b 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.kdtree; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.deeplearning4j.clustering.BaseDL4JTest; import org.joda.time.Duration; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java index fe4fac1b7..c9140942d 100644 --- 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java @@ -269,9 +269,9 @@ public class KMeansTest extends BaseDL4JTest { double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8}; double[] centroid2 = {5.14e8, 5.41e8, 5.68e8, 5.95e8}; - double[] centroid3 = {1.63e8, 1.9e8, 2.17e8, 2.44e8}; - double[] centroid4 = {6.76e8, 7.03e8, 7.3e8, 7.57e8}; - double[] centroid5 = {4.06e8, 4.33e8, 4.6e8, 4.87e8}; + double[] centroid3 = {1000000.0, 2.8E7, 5.5E7, 8.2E7}; + double[] centroid4 = {7.03E8, 7.3E8, 7.57E8, 7.84E8}; + double[] centroid5 = {3.79E8, 4.06E8, 4.33E8, 4.6E8}; assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4); assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4); diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java index 03ad90748..f5ee19403 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.clustering.sptree; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.apache.commons.lang3.time.StopWatch; import org.deeplearning4j.clustering.BaseDL4JTest; import org.junit.Before; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java index 76658438a..5edb3926a 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java @@ -18,12 +18,10 @@ package org.deeplearning4j.clustering.vptree; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.clustering.BaseDL4JTest; import org.deeplearning4j.clustering.sptree.DataPoint; import org.joda.time.Duration; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,7 +31,6 @@ import org.nd4j.linalg.primitives.Counter; import org.nd4j.linalg.primitives.Pair; import java.util.*; -import java.util.concurrent.TimeUnit; import static org.junit.Assert.*; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java index 
67169132c..03d2462a5 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java @@ -16,7 +16,7 @@ package org.deeplearning4j.text.corpora.sentiwordnet; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import org.apache.uima.analysis_engine.AnalysisEngine; import org.apache.uima.cas.CAS; import org.apache.uima.cas.CASException; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index d2d752509..f4dd1a6c5 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models; -import com.google.common.io.Files; -import com.google.common.primitives.Doubles; +import org.nd4j.shade.guava.io.Files; +import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.ArrayUtils; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 61e31b3c7..01b38a644 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models.word2vec; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Ints; import lombok.val; import net.didion.jwnl.data.Word; import org.apache.commons.io.FileUtils; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java index 5be56964c..579caa0a3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTable.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.inmemory; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java 
b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index df502aded..84fc17b7e 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.reader.impl; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NonNull; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java index 75511cae1..f71c56717 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.wordvectors; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java index f97aee9c0..8680a809c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/count/CountMap.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.glove.count; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.nd4j.linalg.primitives.Pair; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java index 4c05b7cc7..64ee79dd4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.paragraphvectors; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import com.google.gson.JsonObject; import com.google.gson.JsonParser; import lombok.Getter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java 
b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index 78a878930..87dd0880a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -16,8 +16,8 @@ package org.deeplearning4j.models.sequencevectors; -import com.google.common.primitives.Ints; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java index d5e789ebf..99263f6bc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/SequenceElement.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.sequencevectors.sequence; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java index cdb4b6c9e..5e51cef20 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/invertedindex/InvertedIndex.java @@ -16,7 +16,7 @@ package org.deeplearning4j.text.invertedindex; -import com.google.common.base.Function; +import org.nd4j.shade.guava.base.Function; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.nd4j.linalg.primitives.Pair; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java index 41536ff70..f5e0ca388 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImplTest.java @@ -16,7 +16,7 @@ package org.deeplearning4j.models.embeddings.wordvectors; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; diff 
--git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index cd3737e04..731a9cd60 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -103,8 +103,7 @@ public class FastTextTest extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") - public void testPredict() throws IOException { + public void testPredict() { String text = "I like soccer"; FastText fastText = new FastText(supModelFile); @@ -119,8 +118,7 @@ public class FastTextTest extends BaseDL4JTest { } @Test - @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") - public void testPredictProbability() throws IOException { + public void testPredictProbability() { String text = "I like soccer"; FastText fastText = new FastText(supModelFile); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/models/fasttext/supervised.model.bin b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/models/fasttext/supervised.model.bin deleted file mode 100644 index dcf564684..000000000 Binary files a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/models/fasttext/supervised.model.bin and /dev/null differ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java index 5834e1647..69767df13 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ConfusionMatrix.java @@ -16,8 +16,8 @@ package org.deeplearning4j.eval; -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; +import org.nd4j.shade.guava.collect.HashMultiset; +import org.nd4j.shade.guava.collect.Multiset; import lombok.Getter; import java.io.Serializable; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java index e964cd9b0..2b00ac375 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/PrecisionRecallCurve.java @@ -16,7 +16,7 @@ package org.deeplearning4j.eval.curves; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.AllArgsConstructor; import lombok.Data; import lombok.EqualsAndHashCode; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java index b66230ddd..5d5e65c2a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/curves/RocCurve.java @@ -16,7 +16,7 @@ package org.deeplearning4j.eval.curves; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Data; import lombok.EqualsAndHashCode; import org.nd4j.shade.jackson.annotation.JsonProperty; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index 6cd8f06b3..5a5ce5665 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.OutputLayerUtil; @@ -40,6 +41,7 @@ import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -172,6 +174,26 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { ComputationGraphConfiguration conf; try { conf = mapper.readValue(json, ComputationGraphConfiguration.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + try{ + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class); + } catch (InvalidTypeIdException e2){ + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work + String msg = e2.getMessage(); + if(msg != null && msg.contains("Could not resolve type id")){ + throw new RuntimeException("Error deserializing ComputationGraphConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e); + } + throw new RuntimeException(e2); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (Exception e) { //Check if this exception came from legacy deserializer... 
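// Illustrative sketch (hypothetical usage, not part of this patch): caller-visible behaviour of the
// InvalidTypeIdException handling added above. Variable and file names are assumptions.
//
//   String legacyJson = FileUtils.readFileToString(new File("net_100a.json"), StandardCharsets.UTF_8);
//   // Pre-1.0.0-beta JSON containing only built-in layers/vertices/preprocessors:
//   // transparently re-parsed through JsonMappers.getLegacyMapper().
//   ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(legacyJson);
//   // Pre-1.0.0-beta JSON containing a custom layer/vertex/preprocessor: RuntimeException
//   // directing the user to load the model in 1.0.0-beta..1.0.0-beta4, save it again, and
//   // only then load it with the current version of DL4J.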
String msg = e.getMessage(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java index ab7fd044b..4f02e3d66 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/InputPreProcessor.java @@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializerHelper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -34,8 +33,7 @@ import java.io.Serializable; * * @author Adam Gibson */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyPreprocessorDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface InputPreProcessor extends Serializable, Cloneable { /** diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index de3373323..52a67ef16 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; @@ -41,6 +42,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossMSE; import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; import org.nd4j.shade.jackson.databind.JsonNode; import org.nd4j.shade.jackson.databind.ObjectMapper; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.databind.node.ArrayNode; import java.io.IOException; @@ -157,6 +159,26 @@ public class MultiLayerConfiguration implements Serializable, Cloneable { ObjectMapper mapper = NeuralNetConfiguration.mapper(); try { conf = mapper.readValue(json, MultiLayerConfiguration.class); + } catch (InvalidTypeIdException e){ + if(e.getMessage().contains("@class")){ + try { + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class); + } catch (InvalidTypeIdException e2){ + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. 
Built-in layers from these formats still work + String msg = e2.getMessage(); + if(msg != null && msg.contains("Could not resolve type id")){ + throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e); + } + throw new RuntimeException(e2); + } catch (IOException e2){ + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); } catch (IOException e) { //Check if this exception came from legacy deserializer... String msg = e.getMessage(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index d8e77a55a..0da5bea13 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -26,20 +26,14 @@ import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; -import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializer; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyPreprocessorDeserializer; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializer; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.weights.IWeightInit; @@ -59,10 +53,6 @@ import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.WeightDecay; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.primitives.Pair; -import org.nd4j.serde.json.LegacyIActivationDeserializer; -import org.nd4j.serde.json.LegacyILossFunctionDeserializer; import org.nd4j.shade.jackson.databind.ObjectMapper; import java.io.IOException; @@ -342,9 +332,7 @@ public class NeuralNetConfiguration implements Serializable, Cloneable { ObjectMapper mapper = mapper(); try { - String ret = mapper.writeValueAsString(this); - return ret; - + return mapper.writeValueAsString(this); } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) { throw new RuntimeException(e); } @@ -384,86 +372,6 @@ public class 
NeuralNetConfiguration implements Serializable, Cloneable { return JsonMappers.getMapper(); } - /** - * Set of classes that can be registered for legacy deserialization. - */ - private static List> REGISTERABLE_CUSTOM_CLASSES = (List>) Arrays.>asList( - Layer.class, - GraphVertex.class, - InputPreProcessor.class, - IActivation.class, - ILossFunction.class, - ReconstructionDistribution.class - ); - - /** - * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution - * ONLY) for JSON deserialization.
- * <br>
- * This is required ONLY when BOTH of the following conditions are met:<br>
- * 1. You want to load a serialized net, saved in 1.0.0-alpha or before, AND<br>
- * 2. The serialized net has a custom Layer, GraphVertex, etc (i.e., one not defined in DL4J)<br>
- * <br>
- * By passing the classes of these layers here, DL4J should be able to deserialize them, in spite of the JSON - * format change between versions. - * - * @param classes Classes to register - */ - public static void registerLegacyCustomClassesForJSON(Class... classes) { - registerLegacyCustomClassesForJSONList(Arrays.>asList(classes)); - } - - /** - * @see #registerLegacyCustomClassesForJSON(Class[]) - */ - public static void registerLegacyCustomClassesForJSONList(List> classes){ - //Default names (i.e., old format for custom JSON format) - List> list = new ArrayList<>(); - for(Class c : classes){ - list.add(new Pair(c.getSimpleName(), c)); - } - registerLegacyCustomClassesForJSON(list); - } - - /** - * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution - * ONLY) for JSON deserialization, with custom names.
- * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])} - * but is added in case it is required in non-standard circumstances. - */ - public static void registerLegacyCustomClassesForJSON(List> classes){ - for(Pair p : classes){ - String s = p.getFirst(); - Class c = p.getRight(); - //Check if it's a valid class to register... - boolean found = false; - for( Class c2 : REGISTERABLE_CUSTOM_CLASSES){ - if(c2.isAssignableFrom(c)){ - if(c2 == Layer.class){ - LegacyLayerDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } else if(c2 == GraphVertex.class){ - LegacyGraphVertexDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } else if(c2 == InputPreProcessor.class){ - LegacyPreprocessorDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } else if(c2 == IActivation.class ){ - LegacyIActivationDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } else if(c2 == ILossFunction.class ){ - LegacyILossFunctionDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } else if(c2 == ReconstructionDistribution.class){ - LegacyReconstructionDistributionDeserializer.registerLegacyClassSpecifiedName(s, (Class)c); - } - - found = true; - } - } - - if(!found){ - throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " + - c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES); - } - } - } - /** * NeuralNetConfiguration builder, used as a starting point for creating a MultiLayerConfiguration or * ComputationGraphConfiguration.
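The hunk above removes the registerLegacyCustomClassesForJSON / registerLegacyCustomClassesForJSONList registration API from NeuralNetConfiguration; legacy (pre-1.0.0-beta, "@class"-less) JSON is instead handled by the InvalidTypeIdException fallback added to fromJson(...) earlier in this patch, and legacy JSON containing custom layers is no longer supported. A minimal caller-side sketch of the resulting load path, assuming an old saved configuration on disk (the file name is illustrative only):

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;

    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Paths;

    public class LegacyConfigLoadSketch {
        public static void main(String[] args) throws Exception {
            // Illustrative path to a configuration saved by 1.0.0-alpha or earlier
            String json = new String(Files.readAllBytes(Paths.get("old-config.json")), StandardCharsets.UTF_8);

            // Current-format JSON is parsed directly; legacy JSON (no "@class" field) now triggers
            // the InvalidTypeIdException branch in fromJson(), which retries with
            // JsonMappers.getLegacyMapper(). Legacy JSON with custom layers fails with a
            // RuntimeException telling the user to re-save the model in 1.0.0-beta to 1.0.0-beta4.
            MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(json);
            System.out.println(conf.toJson());
        }
    }

No call to registerLegacyCustomClassesForJSON(...) is needed (or possible) any more; the same fallback applies to ComputationGraphConfiguration.fromJson(...).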
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index f9af153ad..ee8bbdc64 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.dropout; import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -26,11 +28,11 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; +import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import org.nd4j.shade.jackson.annotation.JsonProperty; -import org.nd4j.util.OneTimeLogger; /** * Implements standard (inverted) dropout.
@@ -64,17 +66,29 @@ import org.nd4j.util.OneTimeLogger; * @author Alex Black */ @Data -@JsonIgnoreProperties({"mask", "helper"}) -@EqualsAndHashCode(exclude = {"mask", "helper"}) +@JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) +@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) @Slf4j public class Dropout implements IDropout { + /** + * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? + * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * (non-CuDNN) implementation for LSTM/GravesLSTM will be used + * + */ + @Getter + @Setter + protected boolean helperAllowFallback = true; + private double p; private ISchedule pSchedule; private transient INDArray mask; private transient DropoutHelper helper; private boolean initializedHelper = false; + private int helperCountFail = 0; + /** * @param activationRetainProbability Probability of retaining an activation - see {@link Dropout} javadoc */ @@ -96,6 +110,18 @@ public class Dropout implements IDropout { this(Double.NaN, activationRetainProbabilitySchedule); } + /** + * When using a helper (CuDNN or MKLDNN in some cases) and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If false, the built-in + * (non-helper) implementation for Dropout will be used + * + * @param allowFallback Whether fallback to non-helper implementation should be used + */ + public Dropout helperAllowFallback(boolean allowFallback) { + this.setHelperAllowFallback(allowFallback); + return this; + } + protected Dropout(@JsonProperty("p") double activationRetainProbability, @JsonProperty("pSchedule") ISchedule activationRetainProbabilitySchedule) { this.p = activationRetainProbability; this.pSchedule = activationRetainProbabilitySchedule; @@ -141,9 +167,29 @@ public class Dropout implements IDropout { initializeHelper(output.dataType()); } - if(helper != null){ - helper.applyDropout(inputActivations, output, p); - return output; + if(helper != null && (helperCountFail == 0 || !isHelperAllowFallback())){ + boolean helperWorked = false; + try { + helper.applyDropout(inputActivations, output, p); + helperWorked = true; + }catch (ND4JOpProfilerException e){ + throw e; //NaN panic etc for debugging + } catch (Exception e){ + if(e.getMessage().contains("Failed to allocate")){ + //This is a memory exception - don't fallback to built-in implementation + throw e; + } + + if(isHelperAllowFallback()){ + helperCountFail++; + log.warn("CuDNN execution failed - falling back on built-in implementation",e); + } else { + throw new RuntimeException("Error during Dropout CuDNN helper forward pass - helperAllowFallback() is set to false", e); + } + } + + if(helperWorked) + return output; } INDArray inputCast = inputActivations; @@ -159,9 +205,29 @@ public class Dropout implements IDropout { @Override public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) { - if(helper != null){ - helper.backprop(gradAtOutput, gradAtInput); - return gradAtInput; + if(helper != null && (helperCountFail == 0 || !isHelperAllowFallback())){ + boolean helperWorked = false; + try { + helper.backprop(gradAtOutput, gradAtInput); + helperWorked = true; + }catch (ND4JOpProfilerException e){ + throw e; //NaN panic etc for debugging + } catch (Exception e){ + if(e.getMessage().contains("Failed to allocate")){ + //This is a memory 
exception - don't fallback to built-in implementation + throw e; + } + + if(isHelperAllowFallback()){ + helperCountFail++; + log.warn("CuDNN execution failed - falling back on built-in implementation",e); + } else { + throw new RuntimeException("Error during Dropout CuDNN helper backprop - helperAllowFallback() is set to false", e); + } + } + + if(helperWorked) + return gradAtInput; } Preconditions.checkState(mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)"); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java index 2cd8d83f4..6ca3b35de 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java @@ -15,7 +15,7 @@ ******************************************************************************/ package org.deeplearning4j.nn.conf.graph; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.inputs.InputType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java index e4968a49e..497cc77f5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java @@ -19,7 +19,6 @@ package org.deeplearning4j.nn.conf.graph; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyGraphVertexDeserializerHelper; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,8 +32,7 @@ import java.io.Serializable; * * @author Alex Black */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyGraphVertexDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public abstract class GraphVertex implements Cloneable, Serializable { @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 805e1729c..85da86fa2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -16,12 +16,15 @@ package org.deeplearning4j.nn.conf.inputs; -import lombok.*; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonIgnore; import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonSubTypes; 
+import org.nd4j.shade.jackson.annotation.JsonProperty; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import java.io.Serializable; @@ -36,12 +39,7 @@ import java.util.Arrays; * @author Alex Black */ @JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) -@JsonSubTypes(value = {@JsonSubTypes.Type(value = InputType.InputTypeFeedForward.class, name = "FeedForward"), - @JsonSubTypes.Type(value = InputType.InputTypeRecurrent.class, name = "Recurrent"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutional.class, name = "Convolutional"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutionalFlat.class, name = "ConvolutionalFlat"), - @JsonSubTypes.Type(value = InputType.InputTypeConvolutional3D.class, name = "Convolutional3D")}) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public abstract class InputType implements Serializable { /** @@ -174,13 +172,16 @@ public abstract class InputType implements Serializable { } - @AllArgsConstructor - @Getter @NoArgsConstructor + @Getter @EqualsAndHashCode(callSuper = false) public static class InputTypeFeedForward extends InputType { private long size; + public InputTypeFeedForward(@JsonProperty("size") long size) { + this.size = size; + } + @Override public Type getType() { return Type.FF; @@ -203,9 +204,8 @@ public abstract class InputType implements Serializable { } } - @Getter @NoArgsConstructor - @AllArgsConstructor + @Getter @EqualsAndHashCode(callSuper = false) public static class InputTypeRecurrent extends InputType { private long size; @@ -215,6 +215,11 @@ public abstract class InputType implements Serializable { this(size, -1); } + public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) { + this.size = size; + this.timeSeriesLength = timeSeriesLength; + } + @Override public Type getType() { return Type.RNN; @@ -245,15 +250,19 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutional extends InputType { private long height; private long width; private long channels; + public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + this.height = height; + this.width = width; + this.channels = channels; + } /** * Return the number of channels / depth for this 2D convolution. 
This method has been deprecated, @@ -298,10 +307,9 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutional3D extends InputType { private Convolution3D.DataFormat dataFormat; private long depth; @@ -309,6 +317,15 @@ public abstract class InputType implements Serializable { private long width; private long channels; + public InputTypeConvolutional3D(@JsonProperty("dataFormat") Convolution3D.DataFormat dataFormat, + @JsonProperty("depth") long depth, @JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + this.dataFormat = dataFormat; + this.depth = depth; + this.height = height; + this.width = width; + this.channels = channels; + } + @Override public Type getType() { return Type.CNN3D; @@ -336,15 +353,20 @@ public abstract class InputType implements Serializable { } } - @AllArgsConstructor + @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = false) - @NoArgsConstructor public static class InputTypeConvolutionalFlat extends InputType { private long height; private long width; private long depth; + public InputTypeConvolutionalFlat(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("depth") long depth) { + this.height = height; + this.width = width; + this.depth = depth; + } + @Override public Type getType() { return Type.CNNFlat; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java index 82bda5647..b051c4b36 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java @@ -17,8 +17,6 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; -import org.deeplearning4j.nn.params.LSTMParamInitializer; -import org.deeplearning4j.nn.weights.WeightInit; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -35,11 +33,13 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer { protected double forgetGateBiasInit; protected IActivation gateActivationFn = new ActivationSigmoid(); + protected boolean helperAllowFallback = true; protected AbstractLSTM(Builder builder) { super(builder); this.forgetGateBiasInit = builder.forgetGateBiasInit; this.gateActivationFn = builder.gateActivationFn; + this.helperAllowFallback = builder.helperAllowFallback; } @AllArgsConstructor @@ -60,6 +60,14 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer { */ protected IActivation gateActivationFn = new ActivationSigmoid(); + /** + * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? + * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * (non-CuDNN) implementation for LSTM/GravesLSTM will be used + * + */ + protected boolean helperAllowFallback = true; + /** * Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or longer-term * dependencies. 
@@ -100,6 +108,18 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer { return (T) this; } + /** + * When using a helper (CuDNN or MKLDNN in some cases) and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If false, the built-in + * (non-helper) implementation for LSTM/GravesLSTM will be used + * + * @param allowFallback Whether fallback to non-helper implementation should be used + */ + public T helperAllowFallback(boolean allowFallback) { + this.setHelperAllowFallback(allowFallback); + return (T) this; + } + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index 53c00acac..4c470fec5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -428,16 +428,31 @@ public class BatchNormalization extends FeedForwardLayer { /** * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? - * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * If set to false, an exception in CuDNN will be propagated back to the user. If true, the built-in * (non-CuDNN) implementation for BatchNormalization will be used * + * @deprecated Use {@link #helperAllowFallback(boolean)} + * * @param allowFallback Whether fallback to non-CuDNN implementation should be used */ + @Deprecated public Builder cudnnAllowFallback(boolean allowFallback) { this.setCudnnAllowFallback(allowFallback); return this; } + /** + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If true, the built-in + * (non-MKL/CuDNN) implementation for BatchNormalizationLayer will be used + * + * @param allowFallback Whether fallback to non-CuDNN implementation should be used + */ + public Builder helperAllowFallback(boolean allowFallback) { + this.cudnnAllowFallback = allowFallback; + return this; + } + /** * How should the moving average of variance be stored? Two different parameterizations are supported. * useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 3d2e35d24..4fdf1e9cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -533,14 +533,29 @@ public class ConvolutionLayer extends FeedForwardLayer { /** * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? - * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * If set to false, an exception in CuDNN will be propagated back to the user. 
If true, the built-in * (non-CuDNN) implementation for ConvolutionLayer will be used * + * @deprecated Use {@link #helperAllowFallback(boolean)} + * * @param allowFallback Whether fallback to non-CuDNN implementation should be used */ + @Deprecated public T cudnnAllowFallback(boolean allowFallback) { this.setCudnnAllowFallback(allowFallback); return (T) this; } + + /** + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If true, the built-in + * (non-MKL/CuDNN) implementation for ConvolutionLayer will be used + * + * @param allowFallback Whether fallback to non-CuDNN implementation should be used + */ + public T helperAllowFallback(boolean allowFallback) { + this.cudnnAllowFallback = allowFallback; + return (T) this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java index d7aa869a1..1a2a89a24 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java @@ -53,11 +53,13 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { private double forgetGateBiasInit; private IActivation gateActivationFn = new ActivationSigmoid(); + protected boolean helperAllowFallback = true; private GravesBidirectionalLSTM(Builder builder) { super(builder); this.forgetGateBiasInit = builder.forgetGateBiasInit; this.gateActivationFn = builder.gateActivationFn; + this.helperAllowFallback = builder.helperAllowFallback; initializeConstraints(builder); } @@ -123,6 +125,14 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { */ private IActivation gateActivationFn = new ActivationSigmoid(); + /** + * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? + * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * (non-CuDNN) implementation for GravesBidirectionalLSTM will be used + * + */ + protected boolean helperAllowFallback = true; + /** * Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or longer-term * dependencies. @@ -163,6 +173,18 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { return this; } + /** + * When using a helper (CuDNN or MKLDNN in some cases) and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. 
If false, the built-in + * (non-helper) implementation for GravesBidirectionalLSTM will be used + * + * @param allowFallback Whether fallback to non-helper implementation should be used + */ + public Builder helperAllowFallback(boolean allowFallback) { + this.setHelperAllowFallback(allowFallback); + return (Builder) this; + } + @SuppressWarnings("unchecked") public GravesBidirectionalLSTM build() { return new GravesBidirectionalLSTM(this); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index 5dfb3b671..25577bd1f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -29,7 +29,6 @@ import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializerHelper; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -45,8 +44,7 @@ import java.util.*; * A neural network layer. */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyLayerDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @Data @NoArgsConstructor public abstract class Layer implements TrainingConfig, Serializable, Cloneable { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index dfc2df9c8..b16703569 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -238,16 +238,31 @@ public class LocalResponseNormalization extends Layer { /** * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? - * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * If set to false, an exception in CuDNN will be propagated back to the user. If true, the built-in * (non-CuDNN) implementation for BatchNormalization will be used * + * @deprecated Use {@link #helperAllowFallback(boolean)} + * * @param allowFallback Whether fallback to non-CuDNN implementation should be used */ + @Deprecated public Builder cudnnAllowFallback(boolean allowFallback) { this.setCudnnAllowFallback(allowFallback); return this; } + /** + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. 
If true, the built-in + * (non-MKL/CuDNN) implementation for LocalResponseNormalizationLayer will be used + * + * @param allowFallback Whether fallback to non-CuDNN implementation should be used + */ + public Builder helperAllowFallback(boolean allowFallback) { + this.cudnnAllowFallback = allowFallback; + return this; + } + @Override public LocalResponseNormalization build() { return new LocalResponseNormalization(this); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 877e216da..0d0ccba9b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -455,15 +455,30 @@ public class Subsampling3DLayer extends NoParamLayer { /** * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? - * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in + * If set to false, an exception in CuDNN will be propagated back to the user. If true, the built-in * (non-CuDNN) implementation for ConvolutionLayer will be used * + * @deprecated Use {@link #helperAllowFallback(boolean)} + * * @param allowFallback Whether fallback to non-CuDNN implementation should be used */ + @Deprecated public T cudnnAllowFallback(boolean allowFallback) { this.setCudnnAllowFallback(allowFallback); return (T) this; } + + /** + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If true, the built-in + * (non-MKL/CuDNN) implementation for Subsampling3DLayer will be used + * + * @param allowFallback Whether fallback to non-CuDNN implementation should be used + */ + public T helperAllowFallback(boolean allowFallback) { + this.cudnnAllowFallback = allowFallback; + return (T) this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index b2e4df6b8..c20526cf1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -480,17 +480,32 @@ public class SubsamplingLayer extends NoParamLayer { } /** - * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed? - * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in - * (non-CuDNN) implementation for ConvolutionLayer will be used + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. 
If true, the built-in + * (non-MKL/CuDNN) implementation for ConvolutionLayer will be used + * + * @deprecated Use {@link #helperAllowFallback(boolean)} * * @param allowFallback Whether fallback to non-CuDNN implementation should be used */ + @Deprecated public T cudnnAllowFallback(boolean allowFallback) { this.cudnnAllowFallback = allowFallback; return (T) this; } + /** + * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? + * If set to false, an exception in the helper will be propagated back to the user. If true, the built-in + * (non-MKL/CuDNN) implementation for SubsamplingLayer will be used + * + * @param allowFallback Whether fallback to non-CuDNN implementation should be used + */ + public T helperAllowFallback(boolean allowFallback) { + this.cudnnAllowFallback = allowFallback; + return (T) this; + } + /** * When doing average pooling, should the padding values be included in the divisor or not?
* Not applicable for max and p-norm pooling.
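Across the layer classes above (BatchNormalization, ConvolutionLayer, LocalResponseNormalization, Subsampling3DLayer, SubsamplingLayer, plus Dropout and the LSTM builders), cudnnAllowFallback(boolean) is deprecated in favor of the new helperAllowFallback(boolean), which also covers the MKL-DNN helpers. A configuration sketch using the new builder methods; layer sizes and hyperparameters are illustrative only:

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.dropout.Dropout;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class HelperFallbackSketch {
        public static MultiLayerConfiguration build() {
            return new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(new ConvolutionLayer.Builder(3, 3)
                            .nOut(16)
                            .activation(Activation.RELU)
                            // Fall back to the built-in implementation if the CuDNN/MKL-DNN helper fails
                            .helperAllowFallback(true)
                            // Dropout gains the same option in this patch
                            .dropOut(new Dropout(0.5).helperAllowFallback(true))
                            .build())
                    .layer(new SubsamplingLayer.Builder(2, 2)
                            // Propagate helper errors to the caller instead of silently falling back
                            .helperAllowFallback(false)
                            .build())
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .nOut(10)
                            .activation(Activation.SOFTMAX)
                            .build())
                    .setInputType(InputType.convolutionalFlat(28, 28, 1))
                    .build();
        }
    }

helperAllowFallback(true) matches the previous default behavior; only helperAllowFallback(false) changes what happens when a helper throws.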
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 1e5eb11f2..0f1a770a8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -22,7 +22,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyIntArrayDeserializer; +import org.deeplearning4j.nn.conf.serde.legacy.LegacyIntArrayDeserializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.linalg.api.buffer.DataType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index 7b2317a8f..f72da09e5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.serde.FrozenLayerDeserializer; import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -48,7 +47,6 @@ import java.util.List; * @author Alex Black */ @EqualsAndHashCode(callSuper = false) -@JsonDeserialize(using = FrozenLayerDeserializer.class) public class FrozenLayer extends Layer { @Getter diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java index 0de9143fe..f58e6b0e7 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java @@ -16,7 +16,6 @@ package org.deeplearning4j.nn.conf.layers.variational; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyReconstructionDistributionDeserializerHelper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonSubTypes; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -32,8 +31,7 @@ import java.io.Serializable; * * @author Alex Black */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", - defaultImpl = LegacyReconstructionDistributionDeserializerHelper.class) +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface ReconstructionDistribution extends Serializable { /** diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index 5194da4a1..d32488363 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -21,12 +21,18 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.weights.*; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.WeightDecay; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.impl.*; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.databind.DeserializationContext; @@ -38,6 +44,8 @@ import org.nd4j.shade.jackson.databind.node.ObjectNode; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; /** * A custom (abstract) deserializer that handles backward compatibility (currently only for updater refactoring that @@ -103,6 +111,24 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } + protected boolean requiresActivationFromLegacy(Layer[] layers){ + for(Layer l : layers){ + if(l instanceof BaseLayer && ((BaseLayer)l).getActivationFn() == null){ + return true; + } + } + return false; + } + + protected boolean requiresLegacyLossHandling(Layer[] layers){ + for(Layer l : layers){ + if(l instanceof BaseOutputLayer && ((BaseOutputLayer)l).getLossFn() == null){ + return true; + } + } + return false; + } + protected void handleUpdaterBackwardCompatibility(BaseLayer layer, ObjectNode on){ if(on != null && on.has("updater")){ String updaterName = on.get("updater").asText(); @@ -220,7 +246,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im } protected void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ - if(on != null && (on.has("weightInit") )){ + if(on != null && on.has("weightInit") ){ //Legacy format JSON if(on.has("weightInit")){ String wi = on.get("weightInit").asText(); @@ -228,8 +254,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im WeightInit w = WeightInit.valueOf(wi); Distribution d = null; if(w == WeightInit.DISTRIBUTION && on.has("dist")){ - //TODO deserialize distribution - String dist = on.get("dist").asText(); + String dist = on.get("dist").toString(); d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class); } IWeightInit iwi = w.getWeightInitFunction(d); @@ -241,6 +266,57 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im } } + //Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : + protected void 
handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ + + if(baseLayer.getActivationFn() == null && on.has("activationFunction")){ + String afn = on.get("activationFunction").asText(); + IActivation a = null; + try { + a = getMap().get(afn.toLowerCase()).newInstance(); + } catch (InstantiationException | IllegalAccessException e){ + //Ignore + } + baseLayer.setActivationFn(a); + } + } + + //0.5.0 and earlier: loss function was an enum like "lossFunction" : "NEGATIVELOGLIKELIHOOD", + protected void handleLossBackwardCompatibility(BaseOutputLayer baseLayer, ObjectNode on){ + if(baseLayer.getLossFn() == null && on.has("activationFunction")) { + String lfn = on.get("lossFunction").asText(); + ILossFunction loss = null; + switch (lfn) { + case "MCXENT": + loss = new LossMCXENT(); + break; + case "MSE": + loss = new LossMSE(); + break; + case "NEGATIVELOGLIKELIHOOD": + loss = new LossNegativeLogLikelihood(); + break; + case "SQUARED_LOSS": + loss = new LossL2(); + break; + case "XENT": + loss = new LossBinaryXENT(); + } + baseLayer.setLossFn(loss); + } + } + + private static Map> activationMap; + private static synchronized Map> getMap(){ + if(activationMap == null){ + activationMap = new HashMap<>(); + for(Activation a : Activation.values()){ + activationMap.put(a.toString().toLowerCase(), a.getActivationFunction().getClass()); + } + } + return activationMap; + } + @Override public void resolve(DeserializationContext ctxt) throws JsonMappingException { ((ResolvableDeserializer) defaultDeserializer).resolve(ctxt); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java index 1c30b053b..50384e518 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; @@ -74,6 +75,8 @@ public class ComputationGraphConfigurationDeserializer boolean attemptIUpdaterFromLegacy = requiresIUpdaterFromLegacy(layers); boolean requireLegacyRegularizationHandling = requiresRegularizationFromLegacy(layers); boolean requiresLegacyWeightInitHandling = requiresWeightInitFromLegacy(layers); + boolean requiresLegacyActivationHandling = requiresActivationFromLegacy(layers); + boolean requiresLegacyLossHandling = requiresLegacyLossHandling(layers); Long charOffsetEnd = null; JsonLocation endLocation = null; @@ -123,6 +126,14 @@ public class ComputationGraphConfigurationDeserializer handleWeightInitBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); } + if(requiresLegacyActivationHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + } + + if(requiresLegacyLossHandling && 
layers[layerIdx] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[layerIdx]).getLossFn() == null){ + handleLossBackwardCompatibility((BaseOutputLayer) layers[layerIdx], (ObjectNode)next); + } + if(layers[layerIdx].getIDropout() == null){ //Check for legacy dropout if(next.has("dropOut")){ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java deleted file mode 100644 index 0eb618e20..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/FrozenLayerDeserializer.java +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde; - -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.deeplearning4j.nn.conf.serde.legacyformat.LegacyLayerDeserializer; -import org.nd4j.shade.jackson.core.JsonFactory; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * A custom deserializer for handling Frozen layers - * This is used to handle the 2 different Layer JSON formats - old/legacy, and current - * - * @author Alex Black - */ -public class FrozenLayerDeserializer extends JsonDeserializer { - @Override - public Layer deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode n = jp.getCodec().readTree(jp); - JsonNode layer = n.get("layer"); - boolean newFormat = layer.has("@class"); - - String internalText = layer.toString(); - Layer internal; - if(newFormat){ - //Standard/new format - internal = NeuralNetConfiguration.mapper().readValue(internalText, Layer.class); - } else { - //Legacy format - JsonFactory factory = new JsonFactory(); - JsonParser parser = factory.createParser(internalText); - parser.setCodec(jp.getCodec()); - internal = new LegacyLayerDeserializer().deserialize(parser, deserializationContext); - } - return new FrozenLayer(internal); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java index a4e29c86e..1cdd85d4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java @@ -27,6 +27,7 @@ import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; +import org.deeplearning4j.nn.conf.serde.legacy.LegacyJsonFormat; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.serde.json.LegacyIActivationDeserializer; @@ -42,8 +43,10 @@ import org.nd4j.shade.jackson.databind.introspect.JacksonAnnotationIntrospector; import org.nd4j.shade.jackson.databind.jsontype.TypeResolverBuilder; import org.nd4j.shade.jackson.databind.module.SimpleModule; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; +import org.nd4j.util.OneTimeLogger; import java.lang.annotation.Annotation; +import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -56,93 +59,14 @@ import java.util.List; @Slf4j public class JsonMappers { - /** - * @deprecated Use {@link DL4JSystemProperties#CUSTOM_REGISTRATION_PROPERTY} - */ - @Deprecated - public static String CUSTOM_REGISTRATION_PROPERTY = DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY; - - static { - String p = System.getProperty(DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY); - if(p != null && !p.isEmpty()){ - String[] split = p.split(","); - List> list = new ArrayList<>(); - for(String s : split){ - try{ - Class c = Class.forName(s); - list.add(c); - } catch (Throwable t){ - log.warn("Error parsing {} system property: class \"{}\" could not be loaded",DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY, s, t); - } - } - - if(list.size() > 0){ - try { - NeuralNetConfiguration.registerLegacyCustomClassesForJSONList(list); - } catch (Throwable t){ - log.warn("Error registering custom classes for legacy JSON deserialization ({} system property)",DL4JSystemProperties.CUSTOM_REGISTRATION_PROPERTY, t); - } - } - } - } - private static ObjectMapper jsonMapper = new ObjectMapper(); private static ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); - /* - Note to developers: The following JSON mappers are for handling legacy format JSON. - Note that after 1.0.0-alpha, the JSON subtype format for layers, preprocessors, graph vertices, - etc were changed from a wrapper object, to an "@class" field. - However, in an attempt to not break saved networks, these mappers are part of the solution. - - How legacy loading works (same pattern for all types - Layer, GraphVertex, InputPreprocesor etc) - 1. Layers etc that have an "@class" field are deserialized as normal - 2. Layers that don't have such a field are mapped (via Layer @JsonTypeInfo) to the LegacyLayerDeserializerHelper class. - 3. LegacyLayerDeserializerHelper has a @JsonDeserialize annotation - we use LegacyLayerDeserialize to handle it - 4. LegacyLayerDeserializer has a list of old names (present in the legacy format JSON) and the corresponding class names - 5. BaseLegacyDeserializer (that LegacyLayerDeserializer extends) does a lookup and handles the deserialization - - Now, as to why we have one ObjectMapper for each type: We can't use the default JSON mapper for the legacy format, - as it'll fail due to not having the expected "@class" annotation. - Consequently, we need to tell Jackson to ignore that specific annotation and deserialize to the specified - class anyway. The ignoring is done via an annotation introspector, defined below in this class. 
- However, we can't just use a single annotation introspector (and hence ObjectMapper) for loading legacy values of - all types - if we did, then any nested types would fail (i.e., an IActivation in a Layer - the IActivation couldn't - be deserialized correctly, as the annotation would be ignored). - - */ - @Getter - private static ObjectMapper jsonMapperLegacyFormatLayer = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatVertex = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatPreproc = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatIActivation = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatILoss = new ObjectMapper(); - @Getter - private static ObjectMapper jsonMapperLegacyFormatReconstruction = new ObjectMapper(); + private static ObjectMapper legacyMapper; static { configureMapper(jsonMapper); configureMapper(yamlMapper); - configureMapper(jsonMapperLegacyFormatLayer); - configureMapper(jsonMapperLegacyFormatVertex); - configureMapper(jsonMapperLegacyFormatPreproc); - configureMapper(jsonMapperLegacyFormatIActivation); - configureMapper(jsonMapperLegacyFormatILoss); - configureMapper(jsonMapperLegacyFormatReconstruction); - - jsonMapperLegacyFormatLayer.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(Layer.class))); - jsonMapperLegacyFormatVertex.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(GraphVertex.class))); - jsonMapperLegacyFormatPreproc.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(InputPreProcessor.class))); - jsonMapperLegacyFormatIActivation.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(IActivation.class))); - jsonMapperLegacyFormatILoss.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(ILossFunction.class))); - jsonMapperLegacyFormatReconstruction.setAnnotationIntrospector(new IgnoreJsonTypeInfoIntrospector(Collections.singletonList(ReconstructionDistribution.class))); - - LegacyIActivationDeserializer.setLegacyJsonMapper(jsonMapperLegacyFormatIActivation); - LegacyILossFunctionDeserializer.setLegacyJsonMapper(jsonMapperLegacyFormatILoss); } /** @@ -152,6 +76,14 @@ public class JsonMappers { return jsonMapper; } + public static synchronized ObjectMapper getLegacyMapper(){ + if(legacyMapper == null){ + legacyMapper = LegacyJsonFormat.getMapper100alpha(); + configureMapper(legacyMapper); + } + return legacyMapper; + } + /** * @return The default/primary ObjectMapper for deserializing network configurations in DL4J (YAML format) */ @@ -182,60 +114,4 @@ public class JsonMappers { ret.registerModule(customDeserializerModule); } - - - /** - * Custom Jackson Introspector to ignore the {@code @JsonTypeYnfo} annotations on layers etc. 
- * This is so we can deserialize legacy format JSON without recursing infinitely, by selectively ignoring - * a set of JsonTypeInfo annotations - */ - @AllArgsConstructor - private static class IgnoreJsonTypeInfoIntrospector extends JacksonAnnotationIntrospector { - - private List classList; - - @Override - protected TypeResolverBuilder _findTypeResolver(MapperConfig config, Annotated ann, JavaType baseType) { - if(ann instanceof AnnotatedClass){ - AnnotatedClass c = (AnnotatedClass)ann; - Class annClass = c.getAnnotated(); - - boolean isAssignable = false; - for(Class c2 : classList){ - if(c2.isAssignableFrom(annClass)){ - isAssignable = true; - break; - } - } - - if( isAssignable ){ - AnnotationMap annotations = (AnnotationMap) ((AnnotatedClass) ann).getAnnotations(); - if(annotations == null || annotations.annotations() == null){ - //Probably not necessary - but here for safety - return super._findTypeResolver(config, ann, baseType); - } - - AnnotationMap newMap = null; - for(Annotation a : annotations.annotations()){ - Class annType = a.annotationType(); - if(annType == JsonTypeInfo.class){ - //Ignore the JsonTypeInfo annotation on the Layer class - continue; - } - if(newMap == null){ - newMap = new AnnotationMap(); - } - newMap.add(a); - } - if(newMap == null) - return null; - - //Pass the remaining annotations (if any) to the original introspector - AnnotatedClass ann2 = c.withAnnotations(newMap); - return super._findTypeResolver(config, ann2, baseType); - } - } - return super._findTypeResolver(config, ann, baseType); - } - } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java index e7bcdc39a..028fef9d3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; @@ -59,6 +60,8 @@ public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializ boolean requiresLegacyRegularizationHandling = requiresRegularizationFromLegacy(layers); boolean requiresLegacyWeightInitHandling = requiresWeightInitFromLegacy(layers); + boolean requiresLegacyActivationHandling = requiresActivationFromLegacy(layers); + boolean requiresLegacyLossHandling = requiresLegacyLossHandling(layers); if(attemptIUpdaterFromLegacy || requiresLegacyRegularizationHandling || requiresLegacyWeightInitHandling) { JsonLocation endLocation = jp.getCurrentLocation(); @@ -115,38 +118,34 @@ public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializ } } - if(requiresLegacyRegularizationHandling) { - if (layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getRegularization() == null) { - if(on.has("layer")){ - //Legacy format - ObjectNode layerNode = (ObjectNode)on.get("layer"); - if(layerNode.has("@class")){ - //Later legacy 
format: class field for JSON subclass - on = layerNode; - } else { - //Early legacy format: wrapper object for JSON subclass - on = (ObjectNode) on.get("layer").elements().next(); - } + if(requiresLegacyRegularizationHandling || requiresLegacyWeightInitHandling || requiresLegacyActivationHandling){ + if(on.has("layer")){ + //Legacy format + ObjectNode layerNode = (ObjectNode)on.get("layer"); + if(layerNode.has("@class")){ + //Later legacy format: class field for JSON subclass + on = layerNode; + } else { + //Early legacy format: wrapper object for JSON subclass + on = (ObjectNode) on.get("layer").elements().next(); } - handleL1L2BackwardCompatibility((BaseLayer) layers[i], on); } } - if(requiresLegacyWeightInitHandling){ - if (layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getWeightInitFn() == null) { - if(on.has("layer")){ - //Legacy format - ObjectNode layerNode = (ObjectNode)on.get("layer"); - if(layerNode.has("@class")){ - //Later legacy format: class field for JSON subclass - on = layerNode; - } else { - //Early legacy format: wrapper object for JSON subclass - on = (ObjectNode) on.get("layer").elements().next(); - } - } - handleWeightInitBackwardCompatibility((BaseLayer) layers[i], on); - } + if(requiresLegacyRegularizationHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getRegularization() == null) { + handleL1L2BackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getWeightInitFn() == null) { + handleWeightInitBackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyActivationHandling && layers[i] instanceof BaseLayer && ((BaseLayer)layers[i]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayer) layers[i], on); + } + + if(requiresLegacyLossHandling && layers[i] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[i]).getLossFn() == null){ + handleLossBackwardCompatibility((BaseOutputLayer) layers[i], on); } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java similarity index 97% rename from deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java rename to deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java index e080d6a78..064219fd1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyIntArrayDeserializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.nn.conf.serde.legacyformat; +package org.deeplearning4j.nn.conf.serde.legacy; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.core.JsonProcessingException; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java new file mode 100644 index 000000000..e421c4b1f --- /dev/null +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java @@ -0,0 +1,175 @@ +package org.deeplearning4j.nn.conf.serde.legacy; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.graph.*; +import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; +import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; +import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.util.MaskLayer; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.conf.layers.variational.*; +import org.deeplearning4j.nn.conf.preprocessor.*; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.linalg.lossfunctions.impl.*; +import org.nd4j.shade.jackson.annotation.JsonSubTypes; +import org.nd4j.shade.jackson.annotation.JsonTypeInfo; +import org.nd4j.shade.jackson.databind.ObjectMapper; + +/** + * This class defines a set of Jackson Mixins - which are a way of using a proxy class with annotations to override + * the existing annotations. + * In 1.0.0-beta, we switched how subtypes were handled in JSON ser/de: from "wrapper object" to "@class field". 
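A hedged sketch of how these mixins get exercised (the JSON literal is a hypothetical 1.0.0-alpha fragment, and this is illustration rather than the deserializers' actual call path): with IActivationMixin registered, the old wrapper-object encoding resolves back to the concrete activation class.

// Sketch: reading a hypothetical legacy-format activation via the mixin-configured mapper.
static IActivation readLegacyActivation() throws java.io.IOException {
    ObjectMapper om = LegacyJsonFormat.getMapper100alpha();
    return om.readValue("{\"ReLU\" : { }}", IActivation.class);   // resolves to ActivationReLU
}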
+ * We use these mixins to allow us to still load the old format + * + * @author Alex Black + */ +public class LegacyJsonFormat { + + private LegacyJsonFormat(){ } + + /** + * Get a mapper (minus general config) suitable for loading old format JSON - 1.0.0-alpha and before + * @return Object mapper + */ + public static ObjectMapper getMapper100alpha(){ + //After 1.0.0-alpha, we switched from wrapper object to @class for subtype information + ObjectMapper om = new ObjectMapper(); + + om.addMixIn(InputPreProcessor.class, InputPreProcessorMixin.class); + om.addMixIn(GraphVertex.class, GraphVertexMixin.class); + om.addMixIn(Layer.class, LayerMixin.class); + om.addMixIn(ReconstructionDistribution.class, ReconstructionDistributionMixin.class); + om.addMixIn(IActivation.class, IActivationMixin.class); + om.addMixIn(ILossFunction.class, ILossFunctionMixin.class); + + return om; + } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = CnnToFeedForwardPreProcessor.class, name = "cnnToFeedForward"), + @JsonSubTypes.Type(value = CnnToRnnPreProcessor.class, name = "cnnToRnn"), + @JsonSubTypes.Type(value = ComposableInputPreProcessor.class, name = "composableInput"), + @JsonSubTypes.Type(value = FeedForwardToCnnPreProcessor.class, name = "feedForwardToCnn"), + @JsonSubTypes.Type(value = FeedForwardToRnnPreProcessor.class, name = "feedForwardToRnn"), + @JsonSubTypes.Type(value = RnnToFeedForwardPreProcessor.class, name = "rnnToFeedForward"), + @JsonSubTypes.Type(value = RnnToCnnPreProcessor.class, name = "rnnToCnn")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class InputPreProcessorMixin { } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ElementWiseVertex.class, name = "ElementWiseVertex"), + @JsonSubTypes.Type(value = MergeVertex.class, name = "MergeVertex"), + @JsonSubTypes.Type(value = SubsetVertex.class, name = "SubsetVertex"), + @JsonSubTypes.Type(value = LayerVertex.class, name = "LayerVertex"), + @JsonSubTypes.Type(value = LastTimeStepVertex.class, name = "LastTimeStepVertex"), + @JsonSubTypes.Type(value = ReverseTimeSeriesVertex.class, name = "ReverseTimeSeriesVertex"), + @JsonSubTypes.Type(value = DuplicateToTimeSeriesVertex.class, name = "DuplicateToTimeSeriesVertex"), + @JsonSubTypes.Type(value = PreprocessorVertex.class, name = "PreprocessorVertex"), + @JsonSubTypes.Type(value = StackVertex.class, name = "StackVertex"), + @JsonSubTypes.Type(value = UnstackVertex.class, name = "UnstackVertex"), + @JsonSubTypes.Type(value = L2Vertex.class, name = "L2Vertex"), + @JsonSubTypes.Type(value = ScaleVertex.class, name = "ScaleVertex"), + @JsonSubTypes.Type(value = L2NormalizeVertex.class, name = "L2NormalizeVertex")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class GraphVertexMixin{ } + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = AutoEncoder.class, name = "autoEncoder"), + @JsonSubTypes.Type(value = ConvolutionLayer.class, name = "convolution"), + @JsonSubTypes.Type(value = Convolution1DLayer.class, name = "convolution1d"), + @JsonSubTypes.Type(value = GravesLSTM.class, name = "gravesLSTM"), + @JsonSubTypes.Type(value = LSTM.class, name = "LSTM"), + @JsonSubTypes.Type(value = GravesBidirectionalLSTM.class, name = "gravesBidirectionalLSTM"), + @JsonSubTypes.Type(value = 
OutputLayer.class, name = "output"), + @JsonSubTypes.Type(value = CenterLossOutputLayer.class, name = "CenterLossOutputLayer"), + @JsonSubTypes.Type(value = RnnOutputLayer.class, name = "rnnoutput"), + @JsonSubTypes.Type(value = LossLayer.class, name = "loss"), + @JsonSubTypes.Type(value = DenseLayer.class, name = "dense"), + @JsonSubTypes.Type(value = SubsamplingLayer.class, name = "subsampling"), + @JsonSubTypes.Type(value = Subsampling1DLayer.class, name = "subsampling1d"), + @JsonSubTypes.Type(value = BatchNormalization.class, name = "batchNormalization"), + @JsonSubTypes.Type(value = LocalResponseNormalization.class, name = "localResponseNormalization"), + @JsonSubTypes.Type(value = EmbeddingLayer.class, name = "embedding"), + @JsonSubTypes.Type(value = ActivationLayer.class, name = "activation"), + @JsonSubTypes.Type(value = VariationalAutoencoder.class, name = "VariationalAutoencoder"), + @JsonSubTypes.Type(value = DropoutLayer.class, name = "dropout"), + @JsonSubTypes.Type(value = GlobalPoolingLayer.class, name = "GlobalPooling"), + @JsonSubTypes.Type(value = ZeroPaddingLayer.class, name = "zeroPadding"), + @JsonSubTypes.Type(value = ZeroPadding1DLayer.class, name = "zeroPadding1d"), + @JsonSubTypes.Type(value = FrozenLayer.class, name = "FrozenLayer"), + @JsonSubTypes.Type(value = Upsampling2D.class, name = "Upsampling2D"), + @JsonSubTypes.Type(value = Yolo2OutputLayer.class, name = "Yolo2OutputLayer"), + @JsonSubTypes.Type(value = RnnLossLayer.class, name = "RnnLossLayer"), + @JsonSubTypes.Type(value = CnnLossLayer.class, name = "CnnLossLayer"), + @JsonSubTypes.Type(value = Bidirectional.class, name = "Bidirectional"), + @JsonSubTypes.Type(value = SimpleRnn.class, name = "SimpleRnn"), + @JsonSubTypes.Type(value = ElementWiseMultiplicationLayer.class, name = "ElementWiseMult"), + @JsonSubTypes.Type(value = MaskLayer.class, name = "MaskLayer"), + @JsonSubTypes.Type(value = MaskZeroLayer.class, name = "MaskZeroLayer"), + @JsonSubTypes.Type(value = Cropping1D.class, name = "Cropping1D"), + @JsonSubTypes.Type(value = Cropping2D.class, name = "Cropping2D")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class LayerMixin {} + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = GaussianReconstructionDistribution.class, name = "Gaussian"), + @JsonSubTypes.Type(value = BernoulliReconstructionDistribution.class, name = "Bernoulli"), + @JsonSubTypes.Type(value = ExponentialReconstructionDistribution.class, name = "Exponential"), + @JsonSubTypes.Type(value = CompositeReconstructionDistribution.class, name = "Composite"), + @JsonSubTypes.Type(value = LossFunctionWrapper.class, name = "LossWrapper")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ReconstructionDistributionMixin {} + + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = ActivationCube.class, name = "Cube"), + @JsonSubTypes.Type(value = ActivationELU.class, name = "ELU"), + @JsonSubTypes.Type(value = ActivationHardSigmoid.class, name = "HardSigmoid"), + @JsonSubTypes.Type(value = ActivationHardTanH.class, name = "HardTanh"), + @JsonSubTypes.Type(value = ActivationIdentity.class, name = "Identity"), + @JsonSubTypes.Type(value = ActivationLReLU.class, name = "LReLU"), + @JsonSubTypes.Type(value = ActivationRationalTanh.class, name = "RationalTanh"), + @JsonSubTypes.Type(value = ActivationRectifiedTanh.class, name = 
"RectifiedTanh"), + @JsonSubTypes.Type(value = ActivationSELU.class, name = "SELU"), + @JsonSubTypes.Type(value = ActivationSwish.class, name = "SWISH"), + @JsonSubTypes.Type(value = ActivationReLU.class, name = "ReLU"), + @JsonSubTypes.Type(value = ActivationRReLU.class, name = "RReLU"), + @JsonSubTypes.Type(value = ActivationSigmoid.class, name = "Sigmoid"), + @JsonSubTypes.Type(value = ActivationSoftmax.class, name = "Softmax"), + @JsonSubTypes.Type(value = ActivationSoftPlus.class, name = "SoftPlus"), + @JsonSubTypes.Type(value = ActivationSoftSign.class, name = "SoftSign"), + @JsonSubTypes.Type(value = ActivationTanH.class, name = "TanH")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class IActivationMixin {} + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.WRAPPER_OBJECT) + @JsonSubTypes(value = {@JsonSubTypes.Type(value = LossBinaryXENT.class, name = "BinaryXENT"), + @JsonSubTypes.Type(value = LossCosineProximity.class, name = "CosineProximity"), + @JsonSubTypes.Type(value = LossHinge.class, name = "Hinge"), + @JsonSubTypes.Type(value = LossKLD.class, name = "KLD"), + @JsonSubTypes.Type(value = LossMAE.class, name = "MAE"), + @JsonSubTypes.Type(value = LossL1.class, name = "L1"), + @JsonSubTypes.Type(value = LossMAPE.class, name = "MAPE"), + @JsonSubTypes.Type(value = LossMCXENT.class, name = "MCXENT"), + @JsonSubTypes.Type(value = LossMSE.class, name = "MSE"), + @JsonSubTypes.Type(value = LossL2.class, name = "L2"), + @JsonSubTypes.Type(value = LossMSLE.class, name = "MSLE"), + @JsonSubTypes.Type(value = LossNegativeLogLikelihood.class, name = "NegativeLogLikelihood"), + @JsonSubTypes.Type(value = LossPoisson.class, name = "Poisson"), + @JsonSubTypes.Type(value = LossSquaredHinge.class, name = "SquaredHinge"), + @JsonSubTypes.Type(value = LossFMeasure.class, name = "FMeasure")}) + @NoArgsConstructor(access = AccessLevel.PRIVATE) + public static class ILossFunctionMixin {} +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java deleted file mode 100644 index 822d8fc3d..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializer.java +++ /dev/null @@ -1,94 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.graph.*; -import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; -import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; -import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Deserializer for GraphVertex JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyGraphVertexDeserializer extends BaseLegacyDeserializer { - - private static final Map LEGACY_NAMES = new HashMap<>(); - - static { - - - List> cList = Arrays.asList( - //All of these vertices had the legacy format name the same as the simple class name - MergeVertex.class, - SubsetVertex.class, - LayerVertex.class, - LastTimeStepVertex.class, - ReverseTimeSeriesVertex.class, - DuplicateToTimeSeriesVertex.class, - PreprocessorVertex.class, - StackVertex.class, - UnstackVertex.class, - L2Vertex.class, - ScaleVertex.class, - L2NormalizeVertex.class, - //These did not previously have a subtype annotation - they use default (which is simple class name) - ElementWiseVertex.class, - PoolHelperVertex.class, - ReshapeVertex.class, - ShiftVertex.class); - - for(Class c : cList){ - LEGACY_NAMES.put(c.getSimpleName(), c.getName()); - } - } - - - @Override - public Map getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { -// return JsonMappers.getMapperLegacyJson(); - return JsonMappers.getJsonMapperLegacyFormatVertex(); - } - - @Override - public Class getDeserializedType() { - return GraphVertex.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java deleted file mode 100644 index bf9a2654a..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyGraphVertexDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyGraphVertexDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.graph.GraphVertex} - */ -@JsonDeserialize(using = LegacyGraphVertexDeserializer.class) -public class LegacyGraphVertexDeserializerHelper { - private LegacyGraphVertexDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java deleted file mode 100644 index 9111dff87..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializer.java +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; -import org.deeplearning4j.nn.conf.layers.misc.ElementWiseMultiplicationLayer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; -import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; -import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; -import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; -import org.deeplearning4j.nn.conf.layers.util.MaskLayer; -import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.HashMap; -import java.util.Map; - -/** - * Deserializer for Layer JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyLayerDeserializer extends BaseLegacyDeserializer { - - private static final Map LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("autoEncoder", AutoEncoder.class.getName()); - LEGACY_NAMES.put("convolution", ConvolutionLayer.class.getName()); - LEGACY_NAMES.put("convolution1d", Convolution1DLayer.class.getName()); - LEGACY_NAMES.put("gravesLSTM", GravesLSTM.class.getName()); - LEGACY_NAMES.put("LSTM", LSTM.class.getName()); - 
LEGACY_NAMES.put("gravesBidirectionalLSTM", GravesBidirectionalLSTM.class.getName()); - LEGACY_NAMES.put("output", OutputLayer.class.getName()); - LEGACY_NAMES.put("CenterLossOutputLayer", CenterLossOutputLayer.class.getName()); - LEGACY_NAMES.put("rnnoutput", RnnOutputLayer.class.getName()); - LEGACY_NAMES.put("loss", LossLayer.class.getName()); - LEGACY_NAMES.put("dense", DenseLayer.class.getName()); - LEGACY_NAMES.put("subsampling", SubsamplingLayer.class.getName()); - LEGACY_NAMES.put("subsampling1d", Subsampling1DLayer.class.getName()); - LEGACY_NAMES.put("batchNormalization", BatchNormalization.class.getName()); - LEGACY_NAMES.put("localResponseNormalization", LocalResponseNormalization.class.getName()); - LEGACY_NAMES.put("embedding", EmbeddingLayer.class.getName()); - LEGACY_NAMES.put("activation", ActivationLayer.class.getName()); - LEGACY_NAMES.put("VariationalAutoencoder", VariationalAutoencoder.class.getName()); - LEGACY_NAMES.put("dropout", DropoutLayer.class.getName()); - LEGACY_NAMES.put("GlobalPooling", GlobalPoolingLayer.class.getName()); - LEGACY_NAMES.put("zeroPadding", ZeroPaddingLayer.class.getName()); - LEGACY_NAMES.put("zeroPadding1d", ZeroPadding1DLayer.class.getName()); - LEGACY_NAMES.put("FrozenLayer", FrozenLayer.class.getName()); - LEGACY_NAMES.put("Upsampling2D", Upsampling2D.class.getName()); - LEGACY_NAMES.put("Yolo2OutputLayer", Yolo2OutputLayer.class.getName()); - LEGACY_NAMES.put("RnnLossLayer", RnnLossLayer.class.getName()); - LEGACY_NAMES.put("CnnLossLayer", CnnLossLayer.class.getName()); - LEGACY_NAMES.put("Bidirectional", Bidirectional.class.getName()); - LEGACY_NAMES.put("SimpleRnn", SimpleRnn.class.getName()); - LEGACY_NAMES.put("ElementWiseMult", ElementWiseMultiplicationLayer.class.getName()); - LEGACY_NAMES.put("MaskLayer", MaskLayer.class.getName()); - LEGACY_NAMES.put("MaskZeroLayer", MaskZeroLayer.class.getName()); - LEGACY_NAMES.put("Cropping1D", Cropping1D.class.getName()); - LEGACY_NAMES.put("Cropping2D", Cropping2D.class.getName()); - - //The following didn't previously have subtype annotations - hence will be using default name (class simple name) - LEGACY_NAMES.put("LastTimeStep", LastTimeStep.class.getName()); - LEGACY_NAMES.put("SpaceToDepthLayer", SpaceToDepthLayer.class.getName()); - LEGACY_NAMES.put("SpaceToBatchLayer", SpaceToBatchLayer.class.getName()); - } - - - @Override - public Map getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getJsonMapperLegacyFormatLayer(); - } - - @Override - public Class getDeserializedType() { - return Layer.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java deleted file mode 100644 index 76f7f7d7d..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyLayerDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyLayerDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.layers.Layer} - */ -@JsonDeserialize(using = LegacyLayerDeserializer.class) -public class LegacyLayerDeserializerHelper { - private LegacyLayerDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java deleted file mode 100644 index bed1c1c5c..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializer.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.graph.*; -import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; -import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; -import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; -import org.deeplearning4j.nn.conf.preprocessor.*; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Deserializer for InputPreProcessor JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyPreprocessorDeserializer extends BaseLegacyDeserializer { - - private static final Map LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("cnnToFeedForward", CnnToFeedForwardPreProcessor.class.getName()); - LEGACY_NAMES.put("cnnToRnn", CnnToRnnPreProcessor.class.getName()); - LEGACY_NAMES.put("composableInput", ComposableInputPreProcessor.class.getName()); - LEGACY_NAMES.put("feedForwardToCnn", FeedForwardToCnnPreProcessor.class.getName()); - LEGACY_NAMES.put("feedForwardToRnn", FeedForwardToRnnPreProcessor.class.getName()); - LEGACY_NAMES.put("rnnToFeedForward", RnnToFeedForwardPreProcessor.class.getName()); - LEGACY_NAMES.put("rnnToCnn", RnnToCnnPreProcessor.class.getName()); - - //Keras preprocessors: they defaulted to class simple name - LEGACY_NAMES.put("KerasFlattenRnnPreprocessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor"); - LEGACY_NAMES.put("ReshapePreprocessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor"); - LEGACY_NAMES.put("TensorFlowCnnToFeedForwardPreProcessor","org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor"); - } - - - @Override - public Map getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { -// return JsonMappers.getMapperLegacyJson(); - return JsonMappers.getJsonMapperLegacyFormatPreproc(); - } - - @Override - public Class getDeserializedType() { - return InputPreProcessor.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java deleted file mode 100644 index 19300ba5f..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyPreprocessorDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyPreprocessorDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.InputPreProcessor} - */ -@JsonDeserialize(using = LegacyPreprocessorDeserializer.class) -public class LegacyPreprocessorDeserializerHelper { - private LegacyPreprocessorDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java deleted file mode 100644 index 06cf37f7e..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializer.java +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import lombok.NonNull; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.layers.variational.*; -import org.deeplearning4j.nn.conf.preprocessor.*; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.nd4j.serde.json.BaseLegacyDeserializer; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.util.HashMap; -import java.util.Map; - -/** - * Deserializer for ReconstructionDistribution JSON in legacy format - see {@link BaseLegacyDeserializer} - * - * @author Alex Black - */ -public class LegacyReconstructionDistributionDeserializer extends BaseLegacyDeserializer { - - private static final Map LEGACY_NAMES = new HashMap<>(); - - static { - LEGACY_NAMES.put("Gaussian", GaussianReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Bernoulli", BernoulliReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Exponential", ExponentialReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("Composite", CompositeReconstructionDistribution.class.getName()); - LEGACY_NAMES.put("LossWrapper", LossFunctionWrapper.class.getName()); - } - - - @Override - public Map getLegacyNamesMap() { - return LEGACY_NAMES; - } - - @Override - public ObjectMapper getLegacyJsonMapper() { - return JsonMappers.getJsonMapperLegacyFormatReconstruction(); - } - - @Override - public Class getDeserializedType() { - return ReconstructionDistribution.class; - } - - public static void registerLegacyClassDefaultName(@NonNull Class clazz){ - registerLegacyClassSpecifiedName(clazz.getSimpleName(), clazz); - } - - public static void registerLegacyClassSpecifiedName(@NonNull String name, @NonNull Class clazz){ - LEGACY_NAMES.put(name, clazz.getName()); - } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java deleted file mode 100644 index 61952d62e..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacyformat/LegacyReconstructionDistributionDeserializerHelper.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.conf.serde.legacyformat; - -import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; - -/** - * Simple helper class to redirect legacy JSON format to {@link LegacyReconstructionDistributionDeserializer} via annotation - * on {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} - */ -@JsonDeserialize(using = LegacyReconstructionDistributionDeserializer.class) -public class LegacyReconstructionDistributionDeserializerHelper { - private LegacyReconstructionDistributionDeserializerHelper(){ } -} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 00c0cf7d6..7c292bafa 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -2460,6 +2460,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); + if(t != null){ + if(t instanceof RuntimeException){ + throw ((RuntimeException)t); + } + throw new RuntimeException("Error during neural network forward pass", t); + } + if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached"); } else { @@ -2780,6 +2787,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); + + if(t != null){ + if(t instanceof RuntimeException){ + throw ((RuntimeException)t); + } + throw new RuntimeException("Error during neural network backpropagation calculation", t); + } } //Now, add the gradients in the order we need them in for flattening (same as params order) @@ -3312,8 +3326,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { @Override public int batchSize() { + //In 99+% of cases, the input and labels dimension 0 size should be identical + //The only real exceptions: space to batch, and batch to space layers + //In those cases, we should base it on the labels size, as this impacts gradient calculation // FIXME: int cast - return (int) inputs[0].size(0); + return labels == null || labels[0] == null ? 
(int) inputs[0].size(0) : (int)labels[0].size(0); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 405889f0e..00ca7e7c4 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -93,7 +93,7 @@ public abstract class BaseLayer tBpttStateMap = new ConcurrentHashMap<>(); + protected int helperCountFail = 0; + public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index 78e15e167..6fc96dc80 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -17,7 +17,6 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -90,7 +89,8 @@ public class GravesBidirectionalLSTM final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr); - final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this.conf, + final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this, + this.conf, this.layerConf().getGateActivationFn(), this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), epsilon, @@ -98,13 +98,14 @@ public class GravesBidirectionalLSTM GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS, gradientViews, maskArray, true, - null, workspaceMgr); + null, workspaceMgr, layerConf().isHelperAllowFallback()); final FwdPassReturn backPass = activateHelperDirectional(true, null, null, true, false, workspaceMgr); - final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this.conf, + final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this, + this.conf, this.layerConf().getGateActivationFn(), this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), epsilon, @@ -112,7 +113,7 @@ public class GravesBidirectionalLSTM GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true, - null, workspaceMgr); + null, workspaceMgr, layerConf().isHelperAllowFallback()); //merge the gradient, which is key value pair of String,INDArray @@ -175,7 +176,7 @@ public class GravesBidirectionalLSTM getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null, forBackprop || (cacheMode != 
CacheMode.NONE && training), true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, maskArray, true, null, - forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr); + forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input, @@ -184,7 +185,7 @@ public class GravesBidirectionalLSTM getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null, forBackprop || (cacheMode != CacheMode.NONE && training), false, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null, - forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr); + forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); cachedPassForward = forwardsEval; cachedPassBackward = backwardsEval; @@ -230,7 +231,7 @@ public class GravesBidirectionalLSTM return LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input, getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true, - null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr); + null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index a2f38b324..13f30b8bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -17,7 +17,6 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -92,11 +91,12 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this.conf, this.layerConf().getGateActivationFn(), this.input, + Pair p = LSTMHelpers.backpropGradientHelper(this, + this.conf, this.layerConf().getGateActivationFn(), this.input, recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null, - workspaceMgr); + workspaceMgr, layerConf().isHelperAllowFallback()); weightNoiseParams.clear(); p.setSecond(backpropDropOutIfPresent(p.getSecond())); @@ -141,7 +141,7 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this.conf, this.layerConf().getGateActivationFn(), this.input, + Pair p = LSTMHelpers.backpropGradientHelper(this, + this.conf, this.layerConf().getGateActivationFn(), this.input, recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY, - LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr); + LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr, + layerConf().isHelperAllowFallback()); 
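The extra isHelperAllowFallback() argument threaded through these backpropGradientHelper calls is driven by the layer configuration. A hedged sketch of setting it (the builder method name is assumed from the accessor used in this diff, and the layer sizes are arbitrary):

// Sketch: allowing the cuDNN/MKL-DNN helper to fall back to the built-in implementation on failure.
LSTM lstmLayer = new LSTM.Builder()
        .nIn(128)
        .nOut(256)
        .helperAllowFallback(true)   // assumed builder name; on helper failure, use the built-in code path
        .build();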
weightNoiseParams.clear(); p.setSecond(backpropDropOutIfPresent(p.getSecond())); @@ -161,7 +160,7 @@ public class LSTM extends BaseRecurrentLayer backpropGradientHelper(final NeuralNetConfiguration conf, + static public Pair backpropGradientHelper(final BaseRecurrentLayer layer, final NeuralNetConfiguration conf, final IActivation gateActivationFn, INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg] final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength, @@ -433,7 +451,8 @@ public class LSTMHelpers { final Map gradientViews, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM final LSTMHelper helper, - final LayerWorkspaceMgr workspaceMgr) { + final LayerWorkspaceMgr workspaceMgr, + final boolean isHelperAllowFallback) { input = input.castTo(inputWeights.dataType()); //No-op if @@ -496,11 +515,29 @@ public class LSTMHelpers { rwGradientsGG = rwGradientsOut.get(all(), NDArrayIndex.point(4 * hiddenLayerSize + 2)).reshape(1, recurrentWeights.size(0)); } - if (helper != null) { - Pair ret = helper.backpropGradient(conf, gateActivationFn, input, recurrentWeights, - inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, forwards, - inputWeightKey, recurrentWeightKey, biasWeightKey, gradientViews, maskArray, - hasPeepholeConnections, workspaceMgr); + if (helper != null && (layer.helperCountFail == 0 || !isHelperAllowFallback)) { + Pair ret = null; + try { + ret = helper.backpropGradient(conf, gateActivationFn, input, recurrentWeights, + inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, forwards, + inputWeightKey, recurrentWeightKey, biasWeightKey, gradientViews, maskArray, + hasPeepholeConnections, workspaceMgr); + }catch (ND4JOpProfilerException e){ + throw e; //NaN panic etc for debugging + } catch (Exception e){ + if(e.getMessage().contains("Failed to allocate")){ + //This is a memory exception - don't fallback to built-in implementation + throw e; + } + + if(isHelperAllowFallback){ + layer.helperCountFail++; + log.warn("MKL/CuDNN execution failed - falling back on built-in implementation",e); + } else { + throw new RuntimeException("Error during LSTM MKL/CuDNN helper backprop - helperAllowFallback() is set to false", e); + } + } + if (ret != null) { return ret; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 5704e57d3..044f444c0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -158,7 +158,7 @@ public class SimpleRnn extends BaseRecurrentLayer end){ dldzNext = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, dldzCurrent.dataType(), dldzCurrent.shape()); INDArray ggCur = workspaceMgr.createUninitialized(ArrayType.BP_WORKING_MEM, gg.dataType(), grg.shape()); - Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, 1)); + Nd4j.getExecutioner().exec(new LayerNormBp(rCurrent, gr, dldzCurrent, dldzNext, ggCur, true, 1)); grg.addi(ggCur); }else{ dldzNext = dldzCurrent; @@ 
-217,6 +217,9 @@ public class SimpleRnn extends BaseRecurrentLayer test-nd4j-native - true + false diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml index 4a8b6f230..0b6b05c26 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/pom.xml @@ -38,13 +38,6 @@ - - 2.1.0 - 2 - 2.11.12 2.11 @@ -102,11 +95,6 @@ jsch ${jsch.version} - - com.google.guava - guava - ${guava.version} - com.google.inject guice @@ -152,21 +140,6 @@ jaxb-impl ${jaxb.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - io.netty netty diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml index 5d0d1403d..eaf0e6be8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml @@ -51,12 +51,6 @@ deeplearning4j-parallel-wrapper ${project.version} - - org.nd4j - nd4j-native - ${project.version} - test - org.nd4j nd4j-parameter-server-client @@ -78,27 +72,11 @@ test - - - io.netty - netty - ${netty.version} - org.scala-lang scala-library ${scala.version} - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - - - com.typesafe - config - ${typesafe.config.version} - ch.qos.logback @@ -110,10 +88,26 @@ test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + - diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index 1af2f4600..4ba5ba4ce 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -16,7 +16,7 @@ package org.deeplearning4j.parallelism; -import com.google.common.util.concurrent.AtomicDouble; +import org.nd4j.shade.guava.util.concurrent.AtomicDouble; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingResult; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java index 4c9520ecd..432e22695 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java +++ 
b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObservable.java @@ -16,7 +16,7 @@ package org.deeplearning4j.parallelism.inference.observers; -import com.google.common.base.Preconditions; +import org.nd4j.shade.guava.base.Preconditions; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml index 16e0176a8..3fded3e4a 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml @@ -32,42 +32,6 @@ 3.4.2 - - - - - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_${scala.binary.version} - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark2.jackson.version} - - - - org.deeplearning4j diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java index 9164a41d0..d14668154 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.models.sequencevectors.functions; import lombok.NonNull; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.sequencevectors.sequence.Sequence; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; @@ -31,64 +30,56 @@ import org.nd4j.parameterserver.distributed.training.TrainingDriver; import org.nd4j.parameterserver.distributed.transport.RoutedTransport; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; /** * @author raver119@gmail.com */ -public class VocabRddFunctionFlat extends BaseFlatMapFunctionAdaptee, T> { +public class VocabRddFunctionFlat implements FlatMapFunction, T> { + protected Broadcast vectorsConfigurationBroadcast; + protected Broadcast paramServerConfigurationBroadcast; + + protected transient VectorsConfiguration configuration; + protected transient SparkElementsLearningAlgorithm ela; + protected transient TrainingDriver driver; + public VocabRddFunctionFlat(@NonNull Broadcast vectorsConfigurationBroadcast, - @NonNull Broadcast paramServerConfigurationBroadcast) { - super(new VocabRddFunctionAdapter(vectorsConfigurationBroadcast, 
paramServerConfigurationBroadcast)); + @NonNull Broadcast paramServerConfigurationBroadcast) { + this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast; + this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast; } + @Override + public Iterator call(Sequence sequence) throws Exception { + if (configuration == null) + configuration = vectorsConfigurationBroadcast.getValue(); - private static class VocabRddFunctionAdapter - implements FlatMapFunctionAdapter, T> { - protected Broadcast vectorsConfigurationBroadcast; - protected Broadcast paramServerConfigurationBroadcast; - - protected transient VectorsConfiguration configuration; - protected transient SparkElementsLearningAlgorithm ela; - protected transient TrainingDriver driver; - - public VocabRddFunctionAdapter(@NonNull Broadcast vectorsConfigurationBroadcast, - @NonNull Broadcast paramServerConfigurationBroadcast) { - this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast; - this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast; - } - - @Override - public Iterable call(Sequence sequence) throws Exception { - if (configuration == null) - configuration = vectorsConfigurationBroadcast.getValue(); - - if (ela == null) { - try { - ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()) - .newInstance(); - } catch (Exception e) { - throw new RuntimeException(e); - } + if (ela == null) { + try { + ela = (SparkElementsLearningAlgorithm) Class.forName(configuration.getElementsLearningAlgorithm()) + .newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); } - driver = ela.getTrainingDriver(); - - // we just silently initialize server - VoidParameterServer.getInstance().init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), - driver); - - // TODO: call for initializeSeqVec here - - List elements = new ArrayList<>(); - - elements.addAll(sequence.getElements()); - - // FIXME: this is PROBABLY bad, we might want to ensure, there's no duplicates. - if (configuration.isTrainSequenceVectors()) - if (!sequence.getSequenceLabels().isEmpty()) - elements.addAll(sequence.getSequenceLabels()); - - return elements; } + driver = ela.getTrainingDriver(); + + // we just silently initialize server + VoidParameterServer.getInstance().init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), + driver); + + // TODO: call for initializeSeqVec here + + List elements = new ArrayList<>(); + + elements.addAll(sequence.getElements()); + + // FIXME: this is PROBABLY bad, we might want to ensure, there's no duplicates. 
+ if (configuration.isTrainSequenceVectors()) + if (!sequence.getSequenceLabels().isEmpty()) + elements.addAll(sequence.getSequenceLabels()); + + return elements.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java index 4718d0118..bbad9f2e3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java @@ -16,28 +16,252 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; import scala.Tuple2; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; +import java.util.concurrent.atomic.AtomicLong; /** * @author jeffreytang * @author raver119@gmail.com */ -public class FirstIterationFunction extends - BaseFlatMapFunctionAdaptee, Long>>, Entry> { +public class FirstIterationFunction implements + FlatMapFunction, Long>>, Entry> { + + private int ithIteration = 1; + private int vectorLength; + private boolean useAdaGrad; + private int batchSize = 0; + private double negative; + private int window; + private double alpha; + private double minAlpha; + private long totalWordCount; + private long seed; + private int maxExp; + private double[] expTable; + private int iterations; + private Map indexSyn0VecMap; + private Map pointSyn1VecMap; + private AtomicLong nextRandom = new AtomicLong(5); + + private volatile VocabCache vocab; + private volatile NegativeHolder negativeHolder; + private AtomicLong cid = new AtomicLong(0); + private AtomicLong aff = new AtomicLong(0); + + public FirstIterationFunction(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - super(new FirstIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); + Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { + + Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); + this.expTable = expTableBroadcast.getValue(); + this.vectorLength = (int) word2vecVarMap.get("vectorLength"); + this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); + this.negative = (double) word2vecVarMap.get("negative"); + this.window = (int) word2vecVarMap.get("window"); + this.alpha = (double) word2vecVarMap.get("alpha"); + this.minAlpha = (double) word2vecVarMap.get("minAlpha"); + this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); + this.seed = (long) word2vecVarMap.get("seed"); + this.maxExp = (int) word2vecVarMap.get("maxExp"); + this.iterations = (int) word2vecVarMap.get("iterations"); + this.batchSize = (int) word2vecVarMap.get("batchSize"); + this.indexSyn0VecMap = new HashMap<>(); + this.pointSyn1VecMap = new 
HashMap<>(); + this.vocab = vocabCacheBroadcast.getValue(); + + if (this.vocab == null) + throw new RuntimeException("VocabCache is null"); + + if (negative > 0) { + negativeHolder = NegativeHolder.getInstance(); + negativeHolder.initHolder(vocab, expTable, this.vectorLength); + } + } + + + + @Override + public Iterator> call(Iterator, Long>> pairIter) { + while (pairIter.hasNext()) { + List, Long>> batch = new ArrayList<>(); + while (pairIter.hasNext() && batch.size() < batchSize) { + Tuple2, Long> pair = pairIter.next(); + List vocabWordsList = pair._1(); + Long sentenceCumSumCount = pair._2(); + batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); + } + + for (int i = 0; i < iterations; i++) { + //System.out.println("Training sentence: " + vocabWordsList); + for (Pair, Long> pair : batch) { + List vocabWordsList = pair.getKey(); + Long sentenceCumSumCount = pair.getValue(); + double currentSentenceAlpha = Math.max(minAlpha, + alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); + trainSentence(vocabWordsList, currentSentenceAlpha); + } + } + } + return indexSyn0VecMap.entrySet().iterator(); + } + + + public void trainSentence(List vocabWordsList, double currentSentenceAlpha) { + + if (vocabWordsList != null && !vocabWordsList.isEmpty()) { + for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { + // Random value ranging from 0 to window size + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + int b = (int) (long) this.nextRandom.get() % window; + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null) { + skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); + } + } + } + } + + public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { + + VocabWord currentWord = vocabWordsList.get(ithWordInSentence); + if (currentWord != null && !vocabWordsList.isEmpty()) { + int end = window * 2 + 1 - b; + for (int a = b; a < end; a++) { + if (a != window) { + int c = ithWordInSentence - window + a; + if (c >= 0 && c < vocabWordsList.size()) { + VocabWord lastWord = vocabWordsList.get(c); + iterateSample(currentWord, lastWord, currentSentenceAlpha); + } + } + } + } + } + + public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { + + + if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) + return; + final int currentWordIndex = w2.getIndex(); + + // error for current word and context + INDArray neu1e = Nd4j.create(vectorLength); + + // First iteration Syn0 is random numbers + INDArray l1 = null; + if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { + l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); + } else { + l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); + } + + // + for (int i = 0; i < w1.getCodeLength(); i++) { + int code = w1.getCodes().get(i); + int point = w1.getPoints().get(i); + if (point < 0) + throw new IllegalStateException("Illegal point " + point); + // Point to + INDArray syn1; + if (pointSyn1VecMap.containsKey(point)) { + syn1 = pointSyn1VecMap.get(point); + } else { + syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros + pointSyn1VecMap.put(point, syn1); + } + + // Dot product of Syn0 and Syn1 vecs + double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); + + if (dot < -maxExp || dot >= maxExp) + continue; + + int idx = (int) ((dot + maxExp) * 
((double) expTable.length / maxExp / 2.0)); + + if (idx >= expTable.length) + continue; + + //score + double f = expTable[idx]; + //gradient + double g = (1 - code - f) * (useAdaGrad ? w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) + : currentSentenceAlpha); + + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); + } + + int target = w1.getIndex(); + int label; + //negative sampling + if (negative > 0) + for (int d = 0; d < negative + 1; d++) { + if (d == 0) + label = 1; + else { + nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); + + // FIXME: int cast + int idx = Math.abs((int) (nextRandom.get() >> 16) % (int) negativeHolder.getTable().length()); + + target = negativeHolder.getTable().getInt(idx); + if (target <= 0) + target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; + + if (target == w1.getIndex()) + continue; + label = 0; + } + + if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) + continue; + + double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target)); + double g; + if (f > maxExp) + g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; + else if (f < -maxExp) + g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); + else { + int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); + if (idx >= expTable.length) + continue; + + g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) + : (label - expTable[idx]) * alpha; + } + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target), neu1e); + + Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, negativeHolder.getSyn1Neg().slice(target)); + } + + + // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. + Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); + + VocabWord word = vocab.elementAtIndex(currentWordIndex); + indexSyn0VecMap.put(word, l1); + } + + private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { + /* + we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on dwo distinct nodes, initial weights will be the same for the same word + */ + return Nd4j.rand( new int[] {1, vectorLength}, lseed * seed).subi(0.5).divi(vectorLength); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java deleted file mode 100644 index 05ca83fc6..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunctionAdapter.java +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.models.embeddings.word2vec; - -import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.models.word2vec.VocabWord; -import org.deeplearning4j.models.word2vec.wordstore.VocabCache; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.Pair; -import scala.Tuple2; - -import java.util.*; -import java.util.concurrent.atomic.AtomicLong; - -/** - * @author jeffreytang - * @author raver119@gmail.com - */ -public class FirstIterationFunctionAdapter implements - FlatMapFunctionAdapter, Long>>, Map.Entry> { - - private int ithIteration = 1; - private int vectorLength; - private boolean useAdaGrad; - private int batchSize = 0; - private double negative; - private int window; - private double alpha; - private double minAlpha; - private long totalWordCount; - private long seed; - private int maxExp; - private double[] expTable; - private int iterations; - private Map indexSyn0VecMap; - private Map pointSyn1VecMap; - private AtomicLong nextRandom = new AtomicLong(5); - - private volatile VocabCache vocab; - private volatile NegativeHolder negativeHolder; - private AtomicLong cid = new AtomicLong(0); - private AtomicLong aff = new AtomicLong(0); - - - - public FirstIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - - Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); - this.expTable = expTableBroadcast.getValue(); - this.vectorLength = (int) word2vecVarMap.get("vectorLength"); - this.useAdaGrad = (boolean) word2vecVarMap.get("useAdaGrad"); - this.negative = (double) word2vecVarMap.get("negative"); - this.window = (int) word2vecVarMap.get("window"); - this.alpha = (double) word2vecVarMap.get("alpha"); - this.minAlpha = (double) word2vecVarMap.get("minAlpha"); - this.totalWordCount = (long) word2vecVarMap.get("totalWordCount"); - this.seed = (long) word2vecVarMap.get("seed"); - this.maxExp = (int) word2vecVarMap.get("maxExp"); - this.iterations = (int) word2vecVarMap.get("iterations"); - this.batchSize = (int) word2vecVarMap.get("batchSize"); - this.indexSyn0VecMap = new HashMap<>(); - this.pointSyn1VecMap = new HashMap<>(); - this.vocab = vocabCacheBroadcast.getValue(); - - if (this.vocab == null) - throw new RuntimeException("VocabCache is null"); - - if (negative > 0) { - negativeHolder = NegativeHolder.getInstance(); - negativeHolder.initHolder(vocab, expTable, this.vectorLength); - } - } - - - - @Override - public Iterable> call(Iterator, Long>> pairIter) { - while (pairIter.hasNext()) { - List, Long>> batch = new ArrayList<>(); - while (pairIter.hasNext() && batch.size() < batchSize) { - Tuple2, Long> pair = pairIter.next(); - List vocabWordsList = pair._1(); - Long sentenceCumSumCount = pair._2(); - batch.add(Pair.of(vocabWordsList, sentenceCumSumCount)); - } - - for (int i = 0; i < iterations; i++) { - //System.out.println("Training sentence: " + vocabWordsList); - for (Pair, Long> pair : batch) { - List vocabWordsList = pair.getKey(); - Long sentenceCumSumCount = pair.getValue(); - double currentSentenceAlpha = Math.max(minAlpha, - alpha - (alpha - minAlpha) * (sentenceCumSumCount / (double) totalWordCount)); - 
trainSentence(vocabWordsList, currentSentenceAlpha); - } - } - } - return indexSyn0VecMap.entrySet(); - } - - - public void trainSentence(List vocabWordsList, double currentSentenceAlpha) { - - if (vocabWordsList != null && !vocabWordsList.isEmpty()) { - for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { - // Random value ranging from 0 to window size - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - int b = (int) (long) this.nextRandom.get() % window; - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null) { - skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); - } - } - } - } - - public void skipGram(int ithWordInSentence, List vocabWordsList, int b, double currentSentenceAlpha) { - - VocabWord currentWord = vocabWordsList.get(ithWordInSentence); - if (currentWord != null && !vocabWordsList.isEmpty()) { - int end = window * 2 + 1 - b; - for (int a = b; a < end; a++) { - if (a != window) { - int c = ithWordInSentence - window + a; - if (c >= 0 && c < vocabWordsList.size()) { - VocabWord lastWord = vocabWordsList.get(c); - iterateSample(currentWord, lastWord, currentSentenceAlpha); - } - } - } - } - } - - public void iterateSample(VocabWord w1, VocabWord w2, double currentSentenceAlpha) { - - - if (w1 == null || w2 == null || w2.getIndex() < 0 || w2.getIndex() == w1.getIndex()) - return; - final int currentWordIndex = w2.getIndex(); - - // error for current word and context - INDArray neu1e = Nd4j.create(vectorLength); - - // First iteration Syn0 is random numbers - INDArray l1 = null; - if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { - l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); - } else { - l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); - } - - // - for (int i = 0; i < w1.getCodeLength(); i++) { - int code = w1.getCodes().get(i); - int point = w1.getPoints().get(i); - if (point < 0) - throw new IllegalStateException("Illegal point " + point); - // Point to - INDArray syn1; - if (pointSyn1VecMap.containsKey(point)) { - syn1 = pointSyn1VecMap.get(point); - } else { - syn1 = Nd4j.zeros(1, vectorLength); // 1 row of vector length of zeros - pointSyn1VecMap.put(point, syn1); - } - - // Dot product of Syn0 and Syn1 vecs - double dot = Nd4j.getBlasWrapper().level1().dot(vectorLength, 1.0, l1, syn1); - - if (dot < -maxExp || dot >= maxExp) - continue; - - int idx = (int) ((dot + maxExp) * ((double) expTable.length / maxExp / 2.0)); - - if (idx >= expTable.length) - continue; - - //score - double f = expTable[idx]; - //gradient - double g = (1 - code - f) * (useAdaGrad ? 
w1.getGradient(i, currentSentenceAlpha, currentSentenceAlpha) - : currentSentenceAlpha); - - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, syn1, neu1e); - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, syn1); - } - - int target = w1.getIndex(); - int label; - //negative sampling - if (negative > 0) - for (int d = 0; d < negative + 1; d++) { - if (d == 0) - label = 1; - else { - nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - - // FIXME: int cast - int idx = Math.abs((int) (nextRandom.get() >> 16) % (int) negativeHolder.getTable().length()); - - target = negativeHolder.getTable().getInt(idx); - if (target <= 0) - target = (int) nextRandom.get() % (vocab.numWords() - 1) + 1; - - if (target == w1.getIndex()) - continue; - label = 0; - } - - if (target >= negativeHolder.getSyn1Neg().rows() || target < 0) - continue; - - double f = Nd4j.getBlasWrapper().dot(l1, negativeHolder.getSyn1Neg().slice(target)); - double g; - if (f > maxExp) - g = useAdaGrad ? w1.getGradient(target, (label - 1), alpha) : (label - 1) * alpha; - else if (f < -maxExp) - g = label * (useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha); - else { - int idx = (int) ((f + maxExp) * (expTable.length / maxExp / 2)); - if (idx >= expTable.length) - continue; - - g = useAdaGrad ? w1.getGradient(target, label - expTable[idx], alpha) - : (label - expTable[idx]) * alpha; - } - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, negativeHolder.getSyn1Neg().slice(target), neu1e); - - Nd4j.getBlasWrapper().level1().axpy(vectorLength, g, l1, negativeHolder.getSyn1Neg().slice(target)); - } - - - // Updated the Syn0 vector based on gradient. Syn0 is not random anymore. - Nd4j.getBlasWrapper().level1().axpy(vectorLength, 1.0f, neu1e, l1); - - VocabWord word = vocab.elementAtIndex(currentWordIndex); - indexSyn0VecMap.put(word, l1); - } - - private INDArray getRandomSyn0Vec(int vectorLength, long lseed) { - /* - we use wordIndex as part of seed here, to guarantee that during word syn0 initialization on dwo distinct nodes, initial weights will be the same for the same word - */ - return Nd4j.rand( new int[] {1, vectorLength}, lseed * seed).subi(0.5).divi(vectorLength); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java index 50ec16ff9..c34156484 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.models.embeddings.word2vec; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.nd4j.linalg.api.ndarray.INDArray; @@ -37,22 +36,7 @@ import java.util.concurrent.atomic.AtomicLong; * @author jeffreytang * @author raver119@gmail.com */ -public class SecondIterationFunction extends - BaseFlatMapFunctionAdaptee, 
Long>>, Entry> { - - public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, - Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { - super(new SecondIterationFunctionAdapter(word2vecVarMapBroadcast, expTableBroadcast, vocabCacheBroadcast)); - } -} - - -/** - * @author jeffreytang - * @author raver119@gmail.com - */ -class SecondIterationFunctionAdapter - implements FlatMapFunctionAdapter, Long>>, Entry> { +public class SecondIterationFunction implements FlatMapFunction, Long>>, Entry> { private int ithIteration = 1; private int vectorLength; @@ -78,7 +62,7 @@ class SecondIterationFunctionAdapter - public SecondIterationFunctionAdapter(Broadcast> word2vecVarMapBroadcast, + public SecondIterationFunction(Broadcast> word2vecVarMapBroadcast, Broadcast expTableBroadcast, Broadcast> vocabCacheBroadcast) { Map word2vecVarMap = word2vecVarMapBroadcast.getValue(); @@ -110,7 +94,7 @@ class SecondIterationFunctionAdapter @Override - public Iterable> call(Iterator, Long>> pairIter) { + public Iterator> call(Iterator, Long>> pairIter) { this.vocabHolder = VocabHolder.getInstance(); this.vocabHolder.setSeed(seed, vectorLength); @@ -139,7 +123,7 @@ class SecondIterationFunctionAdapter } } } - return vocabHolder.getSplit(vocab); + return vocabHolder.getSplit(vocab).iterator(); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml index 0d82bf48c..d8f425286 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml @@ -32,42 +32,6 @@ UTF-8 - - - - - com.fasterxml.jackson.core - jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.module - jackson-module-scala_${scala.binary.version} - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark2.jackson.version} - - - - org.nd4j diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java index 135cfa1c4..900f0e63b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.parameterserver.functions; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper; @@ -32,28 +31,20 @@ import java.util.Iterator; * @author raver119@gmail.com */ -public class SharedFlatMapDataSet extends 
BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapDataSet(TrainingWorker worker) { - super(new SharedFlatMapDataSetAdapter(worker)); - } -} - - -class SharedFlatMapDataSetAdapter implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapDataSet implements FlatMapFunction, R> { private final SharedTrainingWorker worker; - public SharedFlatMapDataSetAdapter(TrainingWorker worker) { + public SharedFlatMapDataSet(TrainingWorker worker) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } /* @@ -70,6 +61,6 @@ class SharedFlatMapDataSetAdapter implements FlatMapFu // all threads in this executor will be blocked here until training finished SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java index d7de24f15..5ce338b0f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.parameterserver.functions; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper; @@ -31,30 +30,20 @@ import java.util.Iterator; /** * Created by raver119 on 13.06.17. */ -public class SharedFlatMapMultiDataSet - extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapMultiDataSet(TrainingWorker worker) { - super(new SharedFlatMapMultiDataSetAdapter(worker)); - } -} - - -class SharedFlatMapMultiDataSetAdapter - implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapMultiDataSet implements FlatMapFunction, R> { private final SharedTrainingWorker worker; - public SharedFlatMapMultiDataSetAdapter(TrainingWorker worker) { + public SharedFlatMapMultiDataSet(TrainingWorker worker) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } /* That's the place where we do our stuff. Here's the plan: @@ -70,6 +59,6 @@ class SharedFlatMapMultiDataSetAdapter // all threads in this executor will be blocked here until training finished SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java index 5028e077f..4c8192ae7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java @@ -18,9 +18,8 @@ package org.deeplearning4j.spark.parameterserver.functions; import org.apache.commons.io.LineIterator; import org.apache.hadoop.conf.Configuration; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -39,11 +38,7 @@ import java.util.Iterator; * * @author raver119@gmail.com */ -public class SharedFlatMapPaths extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapPaths(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { - super(new SharedFlatMapPathsAdapter(worker, loader, hadoopConfig)); - } +public class SharedFlatMapPaths implements FlatMapFunction, R> { public static File toTempFile(Iterator dataSetIterator) throws IOException { File f = Files.createTempFile("SharedFlatMapPaths",".txt").toFile(); @@ -56,17 +51,14 @@ public class SharedFlatMapPaths extends BaseFlatMapFun } return f; } -} - -class SharedFlatMapPathsAdapter implements FlatMapFunctionAdapter, R> { public static Configuration defaultConfig; protected final SharedTrainingWorker worker; protected final DataSetLoader loader; protected final Broadcast hadoopConfig; - public SharedFlatMapPathsAdapter(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { + public SharedFlatMapPaths(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; this.loader = loader; @@ -74,10 +66,10 @@ class SharedFlatMapPathsAdapter implements FlatMapFunc } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } // here we'll be converting out Strings coming out of iterator to DataSets // PathSparkDataSetIterator does that for us @@ -93,7 +85,7 @@ class SharedFlatMapPathsAdapter implements FlatMapFunc // first callee will become master, others will obey and die SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } finally { lineIter.close(); f.delete(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java index a8fbadb2b..3a8b1c213 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.parameterserver.functions; import org.apache.commons.io.LineIterator; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.MultiDataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -37,21 +36,13 @@ import java.util.Iterator; /** * @author raver119@gmail.com */ -public class SharedFlatMapPathsMDS extends BaseFlatMapFunctionAdaptee, R> { - - public SharedFlatMapPathsMDS(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - super(new SharedFlatMapPathsMDSAdapter(worker, loader, hadoopConfig)); - } -} - - -class SharedFlatMapPathsMDSAdapter implements FlatMapFunctionAdapter, R> { +public class SharedFlatMapPathsMDS implements FlatMapFunction, R> { protected final SharedTrainingWorker worker; protected final MultiDataSetLoader loader; protected final Broadcast hadoopConfig; - public SharedFlatMapPathsMDSAdapter(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { + public SharedFlatMapPathsMDS(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { // we're not going to have anything but Shared classes here ever this.worker = (SharedTrainingWorker) worker; this.loader = loader; @@ -59,10 +50,10 @@ class SharedFlatMapPathsMDSAdapter implements FlatMapF } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { //Under some limited circumstances, we might have an empty partition. 
In this case, we should return immediately if(!dataSetIterator.hasNext()){ - return Collections.emptyList(); + return Collections.emptyIterator(); } // here we'll be converting out Strings coming out of iterator to DataSets // PathSparkDataSetIterator does that for us @@ -78,7 +69,7 @@ class SharedFlatMapPathsMDSAdapter implements FlatMapF // first callee will become master, others will obey and die SharedTrainingResult result = SharedTrainingWrapper.getInstance(worker.getInstanceId()).run(worker); - return Collections.singletonList((R) result); + return Collections.singletonList((R) result).iterator(); } finally { lineIter.close(); f.delete(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java index e2c371c06..c3910968e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.api.worker; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -41,30 +40,16 @@ import java.util.Iterator; * * @author Alex Black */ -public class ExecuteWorkerFlatMap extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerFlatMapAdapter(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on DataSets. 
- * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerFlatMapAdapter implements FlatMapFunctionAdapter, R> { +public class ExecuteWorkerFlatMap implements FlatMapFunction, R> { private final TrainingWorker worker; - public ExecuteWorkerFlatMapAdapter(TrainingWorker worker) { + public ExecuteWorkerFlatMap(TrainingWorker worker) { this.worker = worker; } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { WorkerConfiguration dataConfig = worker.getDataConfiguration(); final boolean isGraph = dataConfig.isGraphNetwork(); @@ -79,9 +64,9 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu Pair pair = worker.getFinalResultNoDataWithStats(); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { - return Collections.singletonList(worker.getFinalResultNoData()); + return Collections.singletonList(worker.getFinalResultNoData()).iterator(); } } @@ -131,7 +116,7 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu SparkTrainingStats returnStats = s.build(workerStats); result.getFirst().setStats(returnStats); - return Collections.singletonList(result.getFirst()); + return Collections.singletonList(result.getFirst()).iterator(); } } else { R result; @@ -141,7 +126,7 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu result = worker.processMinibatch(next, net, !batchedIterator.hasNext()); if (result != null) { //Terminate training immediately - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } @@ -155,12 +140,12 @@ class ExecuteWorkerFlatMapAdapter implements FlatMapFu else pair = worker.getFinalResultWithStats(net); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { if (isGraph) - return Collections.singletonList(worker.getFinalResult(graph)); + return Collections.singletonList(worker.getFinalResult(graph)).iterator(); else - return Collections.singletonList(worker.getFinalResult(net)); + return Collections.singletonList(worker.getFinalResult(net)).iterator(); } } finally { //Make sure we shut down the async thread properly... 
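Note on the FlatMapFunction changes above and below (SharedFlatMapDataSet, SharedFlatMapPaths, the word2vec iteration functions, the ExecuteWorker*FlatMap classes): they all follow the same Spark 1.x-to-2.x migration. The DataVec FlatMapFunctionAdapter / BaseFlatMapFunctionAdaptee wrappers, whose call(...) returned an Iterable, are removed, and each class now implements org.apache.spark.api.java.function.FlatMapFunction directly, whose call(...) returns an Iterator. A minimal sketch of the resulting shape, assuming Spark 2.x on the classpath; the class name and element types are illustrative and not taken from the patch:

import org.apache.spark.api.java.function.FlatMapFunction;

import java.util.Collections;
import java.util.Iterator;

// Illustrative only: mirrors the pattern used by the converted classes in this patch.
public class CountingFlatMap implements FlatMapFunction<Iterator<String>, Integer> {
    @Override
    public Iterator<Integer> call(Iterator<String> partition) throws Exception {
        // Empty partitions short-circuit, as in the converted SharedFlatMap* classes
        if (!partition.hasNext()) {
            return Collections.emptyIterator();
        }
        int count = 0;
        while (partition.hasNext()) {
            partition.next();
            count++;
        }
        // Spark 2.x expects an Iterator, hence singletonList(...).iterator()
        return Collections.singletonList(count).iterator();
    }
}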
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java index 15fcd6b89..6fa148394 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.api.worker; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import lombok.AllArgsConstructor; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -39,31 +39,13 @@ import java.util.Iterator; * * @author Alex Black */ -public class ExecuteWorkerMultiDataSetFlatMap - extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerMultiDataSetFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on MultiDataSets. Used only in SparkComputationGraph implementation. - * - * @author Alex Black - */ -class ExecuteWorkerMultiDataSetFlatMapAdapter - implements FlatMapFunctionAdapter, R> { +@AllArgsConstructor +public class ExecuteWorkerMultiDataSetFlatMap implements FlatMapFunction, R> { private final TrainingWorker worker; - public ExecuteWorkerMultiDataSetFlatMapAdapter(TrainingWorker worker) { - this.worker = worker; - } - @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { WorkerConfiguration dataConfig = worker.getDataConfiguration(); boolean stats = dataConfig.isCollectTrainingStats(); @@ -75,7 +57,7 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter if (stats) s.logReturnTime(); //TODO return the results... 
- return Collections.emptyList(); //Sometimes: no data + return Collections.emptyIterator(); //Sometimes: no data } int batchSize = dataConfig.getBatchSizePerWorker(); @@ -118,13 +100,13 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter SparkTrainingStats returnStats = s.build(workerStats); result.getFirst().setStats(returnStats); - return Collections.singletonList(result.getFirst()); + return Collections.singletonList(result.getFirst()).iterator(); } } else { R result = worker.processMinibatch(next, net, !batchedIterator.hasNext()); if (result != null) { //Terminate training immediately - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } @@ -134,9 +116,9 @@ class ExecuteWorkerMultiDataSetFlatMapAdapter s.logReturnTime(); Pair pair = worker.getFinalResultWithStats(net); pair.getFirst().setStats(s.build(pair.getSecond())); - return Collections.singletonList(pair.getFirst()); + return Collections.singletonList(pair.getFirst()).iterator(); } else { - return Collections.singletonList(worker.getFinalResult(net)); + return Collections.singletonList(worker.getFinalResult(net)).iterator(); } } finally { Nd4j.getExecutioner().commit(); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java index 2e1bdb646..4969a055b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator; @@ -33,32 +32,15 @@ import java.util.Iterator; * @author Alex Black */ @Deprecated -public class ExecuteWorkerPDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; public ExecuteWorkerPDSFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerPDSFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded using a PortableDataStream - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -@Deprecated -class ExecuteWorkerPDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; - - public ExecuteWorkerPDSFlatMapAdapter(TrainingWorker worker) { - this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); + this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { return workerFlatMap.call(new PortableDataStreamDataSetIterator(iter)); } } diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java index ebc4d4691..63b82bdaa 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.input.PortableDataStream; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.api.TrainingResult; import org.deeplearning4j.spark.api.TrainingWorker; import org.deeplearning4j.spark.iterator.PortableDataStreamMultiDataSetIterator; @@ -33,32 +32,15 @@ import java.util.Iterator; * @author Alex Black */ @Deprecated -public class ExecuteWorkerPDSMDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPDSMDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; public ExecuteWorkerPDSMDSFlatMap(TrainingWorker worker) { - super(new ExecuteWorkerPDSMDSFlatMapAdapter<>(worker)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized MultiDataSet objects, that can be loaded using a PortableDataStream - * Used for SparkComputationGraph implementations only - * - * @author Alex Black - */ -@Deprecated -class ExecuteWorkerPDSMDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; - - public ExecuteWorkerPDSMDSFlatMapAdapter(TrainingWorker worker) { - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { return workerFlatMap.call(new PortableDataStreamMultiDataSetIterator(iter)); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java index 5e3889a99..e26615ab8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -38,30 +37,15 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteWorkerPathFlatMap - extends BaseFlatMapFunctionAdaptee, R> { +public class ExecuteWorkerPathFlatMap 
implements FlatMapFunction, R> { - public ExecuteWorkerPathFlatMap(TrainingWorker worker, DataSetLoader loader, Broadcast hadoopConfig) { - super(new ExecuteWorkerPathFlatMapAdapter<>(worker, loader, hadoopConfig)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) - * that is specified as a String - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerPathFlatMapAdapter implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; + private final FlatMapFunction, R> workerFlatMap; private final DataSetLoader dataSetLoader; private final int maxDataSetObjects; private final Broadcast hadoopConfig; - public ExecuteWorkerPathFlatMapAdapter(TrainingWorker worker, DataSetLoader dataSetLoader, Broadcast hadoopConfig) { - this.workerFlatMap = new ExecuteWorkerFlatMapAdapter<>(worker); + public ExecuteWorkerPathFlatMap(TrainingWorker worker, DataSetLoader dataSetLoader, Broadcast hadoopConfig) { + this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker); this.dataSetLoader = dataSetLoader; this.hadoopConfig = hadoopConfig; @@ -84,7 +68,7 @@ class ExecuteWorkerPathFlatMapAdapter implements FlatM } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List list = new ArrayList<>(); int count = 0; while (iter.hasNext() && count++ < maxDataSetObjects) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java index 072425f18..47e53bd6c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.api.worker; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.MultiDataSetLoader; import org.deeplearning4j.spark.api.TrainingResult; @@ -38,31 +37,14 @@ import java.util.List; * * @author Alex Black */ -public class ExecuteWorkerPathMDSFlatMap - extends BaseFlatMapFunctionAdaptee, R> { - - public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - super(new ExecuteWorkerPathMDSFlatMapAdapter<>(worker, loader, hadoopConfig)); - } -} - - -/** - * A FlatMapFunction for executing training on serialized DataSet objects, that can be loaded from a path (local or HDFS) - * that is specified as a String - * Used in both SparkDl4jMultiLayer and SparkComputationGraph implementations - * - * @author Alex Black - */ -class ExecuteWorkerPathMDSFlatMapAdapter - implements FlatMapFunctionAdapter, R> { - private final FlatMapFunctionAdapter, R> workerFlatMap; +public class ExecuteWorkerPathMDSFlatMap implements FlatMapFunction, R> { + private final FlatMapFunction, R> workerFlatMap; private MultiDataSetLoader loader; private 
final int maxDataSetObjects; private final Broadcast hadoopConfig; - public ExecuteWorkerPathMDSFlatMapAdapter(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { - this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMapAdapter<>(worker); + public ExecuteWorkerPathMDSFlatMap(TrainingWorker worker, MultiDataSetLoader loader, Broadcast hadoopConfig) { + this.workerFlatMap = new ExecuteWorkerMultiDataSetFlatMap<>(worker); this.loader = loader; this.hadoopConfig = hadoopConfig; @@ -85,7 +67,7 @@ class ExecuteWorkerPathMDSFlatMapAdapter } @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List list = new ArrayList<>(); int count = 0; while (iter.hasNext() && count++ < maxDataSetObjects) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java index fd3a5a2fc..9513193e4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.data; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import lombok.AllArgsConstructor; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import java.util.ArrayList; @@ -38,37 +38,12 @@ import java.util.List; * * @author Alex Black */ -public class BatchDataSetsFunction extends BaseFlatMapFunctionAdaptee, DataSet> { - - public BatchDataSetsFunction(int minibatchSize) { - super(new BatchDataSetsFunctionAdapter(minibatchSize)); - } -} - - -/** - * Function used to batch DataSet objects together. Typically used to combine singe-example DataSet objects out of - * something like {@link org.deeplearning4j.spark.datavec.DataVecDataSetFunction} together into minibatches.
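Every class rewritten in this changeset follows the same migration: the DataVec FlatMapFunctionAdapter / BaseFlatMapFunctionAdaptee wrapper pair is collapsed into one class that implements Spark's own FlatMapFunction, whose call method returns an Iterator under the Spark 2.x API instead of the Iterable the adapters produced. A minimal sketch of the resulting shape, with a hypothetical pass-through body standing in for the DL4J-specific logic:

    import org.apache.spark.api.java.function.FlatMapFunction;

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    // Post-migration shape shared by the classes in this diff: implement
    // FlatMapFunction directly and hand results back as an Iterator
    // (the removed adapter classes returned the List itself as an Iterable).
    public class IdentityPartitionFunction<T> implements FlatMapFunction<Iterator<T>, T> {
        @Override
        public Iterator<T> call(Iterator<T> input) {
            List<T> out = new ArrayList<>();
            while (input.hasNext()) {
                out.add(input.next());      // the real classes batch, split or score here
            }
            return out.iterator();          // Spark 2.x contract: Iterator, not Iterable
        }
    }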
- * - * Usage: - *
- * {@code
- *      RDD mySingleExampleDataSets = ...;
- *      RDD batchData = mySingleExampleDataSets.mapPartitions(new BatchDataSetsFunction(batchSize));
- * }
- * 
- * - * @author Alex Black - */ -class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { +@AllArgsConstructor +public class BatchDataSetsFunction implements FlatMapFunction, DataSet> { private final int minibatchSize; - public BatchDataSetsFunctionAdapter(int minibatchSize) { - this.minibatchSize = minibatchSize; - } - @Override - public Iterable call(Iterator iter) throws Exception { + public Iterator call(Iterator iter) throws Exception { List out = new ArrayList<>(); while (iter.hasNext()) { List list = new ArrayList<>(); @@ -88,6 +63,6 @@ class BatchDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { - - public SplitDataSetsFunction() { - super(new SplitDataSetsFunctionAdapter()); - } -} - - -/** - * Take an existing DataSet object, and split it into multiple DataSet objects with one example in each - * - * Usage: - *
- * {@code
- *      RDD myBatchedExampleDataSets = ...;
- *      RDD singleExampleDataSets = myBatchedExampleDataSets.mapPartitions(new SplitDataSets(batchSize));
- * }
- * 
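The removed Usage listings above still describe how the rewritten helpers are driven: both operate per partition through mapPartitions. A short sketch of that usage against the Spark 2.x API, assuming SplitDataSetsFunction sits alongside BatchDataSetsFunction in org.deeplearning4j.spark.data (MiniBatchingExample and the method names are illustrative):

    import org.apache.spark.api.java.JavaRDD;
    import org.deeplearning4j.spark.data.BatchDataSetsFunction;
    import org.deeplearning4j.spark.data.SplitDataSetsFunction;
    import org.nd4j.linalg.dataset.DataSet;

    public class MiniBatchingExample {

        // Combine single-example DataSets into minibatches of the requested size.
        public static JavaRDD<DataSet> batch(JavaRDD<DataSet> singleExamples, int batchSize) {
            return singleExamples.mapPartitions(new BatchDataSetsFunction(batchSize));
        }

        // Split batched DataSets back into one-example DataSets.
        public static JavaRDD<DataSet> split(JavaRDD<DataSet> minibatches) {
            return minibatches.mapPartitions(new SplitDataSetsFunction());
        }
    }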
- * - * @author Alex Black - */ -class SplitDataSetsFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { +public class SplitDataSetsFunction implements FlatMapFunction, DataSet> { @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { List out = new ArrayList<>(); while (dataSetIterator.hasNext()) { out.addAll(dataSetIterator.next().asList()); } - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java index e0557d10c..1ccf54b91 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java @@ -17,12 +17,13 @@ package org.deeplearning4j.spark.data.shuffle; import org.apache.spark.api.java.JavaRDD; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import scala.Tuple2; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Random; @@ -34,34 +35,17 @@ import java.util.Random; * * @author Alex Black */ -public class SplitDataSetExamplesPairFlatMapFunction extends BasePairFlatMapFunctionAdaptee { - - public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { - super(new SplitDataSetExamplesPairFlatMapFunctionAdapter(maxKeyIndex)); - } -} - - -/** - * A PairFlatMapFunction that splits each example in a {@link DataSet} object into its own {@link DataSet}. - * Also adds a random key (integer value) in the range 0 to maxKeyIndex-1.
- * - * Used in {@link org.deeplearning4j.spark.util.SparkUtils#shuffleExamples(JavaRDD, int, int)} - * - * @author Alex Black - */ -class SplitDataSetExamplesPairFlatMapFunctionAdapter - implements FlatMapFunctionAdapter> { +public class SplitDataSetExamplesPairFlatMapFunction implements PairFlatMapFunction { private transient Random r; private int maxKeyIndex; - public SplitDataSetExamplesPairFlatMapFunctionAdapter(int maxKeyIndex) { + public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { this.maxKeyIndex = maxKeyIndex; } @Override - public Iterable> call(DataSet dataSet) throws Exception { + public Iterator> call(DataSet dataSet) throws Exception { if (r == null) { r = new Random(); } @@ -72,6 +56,6 @@ class SplitDataSetExamplesPairFlatMapFunctionAdapter out.add(new Tuple2<>(r.nextInt(maxKeyIndex), ds)); } - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java index 37ef1e7ab..653bbc75d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java @@ -16,9 +16,9 @@ package org.deeplearning4j.spark.datavec; +import lombok.AllArgsConstructor; import org.apache.spark.api.java.JavaRDD; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.FlatMapFunction; import org.nd4j.linalg.dataset.DataSet; import java.io.Serializable; @@ -44,22 +44,12 @@ public class RDDMiniBatches implements Serializable { return toSplitJava.mapPartitions(new MiniBatchFunction(miniBatches)); } - public static class MiniBatchFunction extends BaseFlatMapFunctionAdaptee, DataSet> { - - public MiniBatchFunction(int batchSize) { - super(new MiniBatchFunctionAdapter(batchSize)); - } - } - - static class MiniBatchFunctionAdapter implements FlatMapFunctionAdapter, DataSet> { - private int batchSize = 10; - - public MiniBatchFunctionAdapter(int batchSize) { - this.batchSize = batchSize; - } + @AllArgsConstructor + public static class MiniBatchFunction implements FlatMapFunction, DataSet> { + private int batchSize; @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { List ret = new ArrayList<>(); List temp = new ArrayList<>(); while (dataSetIterator.hasNext()) { @@ -74,10 +64,7 @@ public class RDDMiniBatches implements Serializable { if (temp.size() > 0) ret.add(DataSet.merge(temp)); - return ret; + return ret.iterator(); } - } - - } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java index 6d717371a..9412a3cbb 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java +++ 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java @@ -16,8 +16,8 @@ package org.deeplearning4j.spark.impl.common.repartition; -import com.google.common.base.Predicate; -import com.google.common.collect.Collections2; +import org.nd4j.shade.guava.base.Predicate; +import org.nd4j.shade.guava.collect.Collections2; import org.apache.spark.Partitioner; import scala.Tuple2; @@ -25,8 +25,8 @@ import java.util.ArrayList; import java.util.List; import java.util.Random; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; +import static org.nd4j.shade.guava.base.Preconditions.checkArgument; +import static org.nd4j.shade.guava.base.Preconditions.checkNotNull; /** * This is a custom partitioner that rebalances a minimum of elements @@ -97,12 +97,13 @@ public class HashingBalancedPartitioner extends Partitioner { @Override public int numPartitions() { - return Collections2.filter(partitionWeightsByClass.get(0), new Predicate() { - @Override - public boolean apply(Double aDouble) { - return aDouble >= 0; - } - }).size(); + List list = partitionWeightsByClass.get(0); + int count = 0; + for(Double d : list){ + if(d >= 0) + count++; + } + return count; } @Override diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java index b540cff82..23b702444 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java @@ -16,8 +16,7 @@ package org.deeplearning4j.spark.impl.common.repartition; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.apache.spark.api.java.function.PairFlatMapFunction; import scala.Tuple2; import java.util.ArrayList; @@ -30,22 +29,14 @@ import java.util.List; * * @author Alex Black */ -public class MapTupleToPairFlatMap extends BasePairFlatMapFunctionAdaptee>, T, U> { - - public MapTupleToPairFlatMap() { - super(new MapTupleToPairFlatMapAdapter()); - } -} - - -class MapTupleToPairFlatMapAdapter implements FlatMapFunctionAdapter>, Tuple2> { +public class MapTupleToPairFlatMap implements PairFlatMapFunction>, T, U> { @Override - public Iterable> call(Iterator> tuple2Iterator) throws Exception { + public Iterator> call(Iterator> tuple2Iterator) throws Exception { List> list = new ArrayList<>(); while (tuple2Iterator.hasNext()) { list.add(tuple2Iterator.next()); } - return list; + return list.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java similarity index 87% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java 
rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java index 58f092de8..99fcdc65b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunctionAdapter.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java @@ -27,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * @param Type of key, associated with each example. Used to keep track of which score belongs to which example * @author Alex Black */ -public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { +public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVaeScoreWithKeyFunction { private final boolean useLogProbability; private final int numSamples; @@ -39,8 +39,8 @@ public abstract class BaseVaeReconstructionProbWithKeyFunctionAdapter extends * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public BaseVaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - boolean useLogProbability, int batchSize, int numSamples) { + public BaseVaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, + boolean useLogProbability, int batchSize, int numSamples) { super(params, jsonConfig, batchSize); this.useLogProbability = useLogProbability; this.numSamples = numSamples; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java similarity index 88% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java rename to deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java index 19b4e1b14..da6a374c4 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunctionAdapter.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java @@ -18,8 +18,9 @@ package org.deeplearning4j.spark.impl.common.score; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -38,8 +39,7 @@ import java.util.List; * @author Alex Black */ @Slf4j -public abstract class BaseVaeScoreWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { +public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunction>, K, Double> { protected final Broadcast params; 
protected final Broadcast jsonConfig; @@ -51,7 +51,7 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public BaseVaeScoreWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { + public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; @@ -63,9 +63,9 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } VariationalAutoencoder vae = getVaeLayer(); @@ -108,6 +108,6 @@ public abstract class BaseVaeScoreWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java index dd02c5d37..cdb41ba33 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java @@ -17,12 +17,10 @@ package org.deeplearning4j.spark.impl.graph.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.impl.evaluation.EvaluationRunner; import org.nd4j.evaluation.IEvaluation; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import java.util.Collections; @@ -36,26 +34,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateMDSFlatMapFunction - extends BaseFlatMapFunctionAdaptee, T[]> { - - public IEvaluateMDSFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, - T... evaluations) { - super(new IEvaluateMDSFlatMapFunctionAdapter<>(json, params, evalNumWorkers, evalBatchSize, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. - * - * @author Alex Black - */ -@Slf4j -class IEvaluateMDSFlatMapFunctionAdapter - implements FlatMapFunctionAdapter, T[]> { +public class IEvaluateMDSFlatMapFunction implements FlatMapFunction, T[]> { protected Broadcast json; protected Broadcast params; @@ -70,7 +49,7 @@ class IEvaluateMDSFlatMapFunctionAdapter * this. 
Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateMDSFlatMapFunctionAdapter(Broadcast json, Broadcast params, int evalNumWorkers, + public IEvaluateMDSFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, T[] evaluations) { this.json = json; this.params = params; @@ -80,13 +59,13 @@ class IEvaluateMDSFlatMapFunctionAdapter } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } Future f = EvaluationRunner.getInstance().execute( @@ -94,9 +73,9 @@ class IEvaluateMDSFlatMapFunctionAdapter IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList((T[])result); + return Collections.singletonList((T[])result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java index cea6a7ab0..d520605a3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.impl.graph.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.api.loader.DataSetLoader; import org.deeplearning4j.api.loader.MultiDataSetLoader; @@ -43,26 +42,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateMDSPathsFlatMapFunction - extends BaseFlatMapFunctionAdaptee, IEvaluation[]> { - - public IEvaluateMDSPathsFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, - DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, - Broadcast configuration, IEvaluation... evaluations) { - super(new IEvaluateMDSPathsFlatMapFunctionAdapter(json, params, evalNumWorkers, evalBatchSize, dsLoader, mdsLoader, configuration, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. 
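With the adapters gone, the IEvaluate*FlatMapFunction classes return their single aggregated result per partition as an Iterator: Collections.emptyIterator() for an empty partition, otherwise Collections.singletonList(result).iterator(). A generic sketch of that pattern, with a String aggregate standing in for the IEvaluation[] result and SingleResultPerPartition as an illustrative name:

    import org.apache.spark.api.java.function.FlatMapFunction;

    import java.util.Collections;
    import java.util.Iterator;

    public class SingleResultPerPartition implements FlatMapFunction<Iterator<String>, String> {
        @Override
        public Iterator<String> call(Iterator<String> data) {
            if (!data.hasNext()) {
                return Collections.emptyIterator();            // nothing to evaluate in this partition
            }
            StringBuilder sb = new StringBuilder();
            while (data.hasNext()) {
                sb.append(data.next()).append('\n');           // aggregate the whole partition
            }
            return Collections.singletonList(sb.toString()).iterator();
        }
    }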
- * - * @author Alex Black - */ -@Slf4j -class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter, IEvaluation[]> { +public class IEvaluateMDSPathsFlatMapFunction implements FlatMapFunction, IEvaluation[]> { protected Broadcast json; protected Broadcast params; @@ -80,7 +60,7 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< * this. Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateMDSPathsFlatMapFunctionAdapter(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, + public IEvaluateMDSPathsFlatMapFunction(Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, Broadcast configuration, IEvaluation[] evaluations) { this.json = json; this.params = params; @@ -93,9 +73,9 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< } @Override - public Iterable call(Iterator paths) throws Exception { + public Iterator call(Iterator paths) throws Exception { if (!paths.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiDataSetIterator iter; @@ -109,9 +89,9 @@ class IEvaluateMDSPathsFlatMapFunctionAdapter implements FlatMapFunctionAdapter< Future f = EvaluationRunner.getInstance().execute(evaluations, evalNumWorkers, evalBatchSize, null, iter, true, json, params); IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList(result); + return Collections.singletonList(result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index 107b22284..0e5f01343 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -21,7 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunctionAdapter; +import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -33,7 +33,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * @author Alex Black * @see CGVaeReconstructionProbWithKeyFunction */ -public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunctionAdapter { +public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { /** diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java 
b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index 4db797bb2..835bb8fa7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -21,7 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * * @author Alex Black */ -public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunctionAdapter { +public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { /** diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java index 3f84a6fb2..cec2f5b17 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java @@ -16,11 +16,13 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -40,48 +42,19 @@ import java.util.List; * @param Type of key, associated with each example. Used to keep track of which output belongs to which input example * @author Alex Black */ -public class GraphFeedForwardWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, INDArray[]> { - - public GraphFeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { - super(new GraphFeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to feed-forward examples, and get the network output (for example, class probabilities). - * A key value is used to keey track of which output corresponds to which input. - * - * @param Type of key, associated with each example. 
Used to keep track of which output belongs to which input example - * @author Alex Black - */ -class GraphFeedForwardWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(GraphFeedForwardWithKeyFunction.class); +@Slf4j +@AllArgsConstructor +public class GraphFeedForwardWithKeyFunction implements PairFlatMapFunction>, K, INDArray[]> { private final Broadcast params; private final Broadcast jsonConfig; private final int batchSize; - /** - * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json - * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) - */ - public GraphFeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - int batchSize) { - this.params = params; - this.jsonConfig = jsonConfig; - this.batchSize = batchSize; - } - @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -129,7 +102,7 @@ class GraphFeedForwardWithKeyFunctionAdapter } if (tupleCount == 0) { - return Collections.emptyList(); + return Collections.emptyIterator(); } List> output = new ArrayList<>(tupleCount); @@ -198,7 +171,7 @@ class GraphFeedForwardWithKeyFunctionAdapter Nd4j.getExecutioner().commit(); - return output; + return output.iterator(); } private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java index ee21d483e..44474248d 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java @@ -16,11 +16,12 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.DoubleFlatMapFunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -41,31 +42,15 @@ import java.util.List; * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { - - public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually. Note that scoring is batched for computational efficiency.
- * This is essentially a Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
- * Note: This method returns a score for each example, but the association between examples and scores is lost. In - * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} - * @author Alex Black - * @see ScoreExamplesWithKeyFunction - */ -class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter, Double> { - protected static final Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); +@Slf4j +public class ScoreExamplesFunction implements DoubleFlatMapFunction> { private final Broadcast params; private final Broadcast jsonConfig; private final boolean addRegularization; private final int batchSize; - public ScoreExamplesFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -75,9 +60,9 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter call(Iterator iterator) throws Exception { + public Iterator call(Iterator iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -121,6 +106,6 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter Type of key, associated with each example. Used to keep track of which score belongs to which example * @see ScoreExamplesFunction */ -public class ScoreExamplesWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually, where each example is associated with a particular key
- * Note that scoring is batched for computational efficiency.
- * This is the Spark implementation of the {@link ComputationGraph#scoreExamples(MultiDataSet, boolean)} method
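The keyed scoring functions (the graph and multilayer ScoreExamplesWithKeyFunction classes) now implement PairFlatMapFunction directly, emitting an Iterator of Tuple2 key/score pairs per partition. A stripped-down sketch of that contract, where KeyedScoreSketch and the toy sum stand in for the real broadcast-configured network scoring:

    import org.apache.spark.api.java.function.PairFlatMapFunction;
    import scala.Tuple2;

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    public class KeyedScoreSketch<K> implements PairFlatMapFunction<Iterator<Tuple2<K, double[]>>, K, Double> {
        @Override
        public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, double[]>> iterator) {
            List<Tuple2<K, Double>> ret = new ArrayList<>();
            while (iterator.hasNext()) {
                Tuple2<K, double[]> next = iterator.next();
                double score = 0.0;
                for (double d : next._2()) {
                    score += d;                 // placeholder for the per-example network score
                }
                ret.add(new Tuple2<>(next._1(), score));
            }
            return ret.iterator();              // key/score association is preserved
        }
    }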
- * Note: The MultiDataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association - * between keys and data sets to score) - * @author Alex Black - * @param Type of key, associated with each example. Used to keep track of which score belongs to which example - * @see ScoreExamplesFunction - */ -class ScoreExamplesWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); +@Slf4j +public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>, K, Double> { private final Broadcast params; private final Broadcast jsonConfig; @@ -78,7 +55,7 @@ class ScoreExamplesWithKeyFunctionAdapter * @param addRegularizationTerms if true: add regularization terms (l1/l2) if applicable; false: don't add regularization terms * @param batchSize Batch size to use when scoring examples */ - public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -88,9 +65,9 @@ class ScoreExamplesWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); @@ -140,6 +117,6 @@ class ScoreExamplesWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java index 240aeb01e..3fdc7fd1c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -37,35 +36,23 @@ import java.util.Iterator; import java.util.List; /** Function used to score a DataSet using a ComputationGraph */ -public class ScoreFlatMapFunctionCGDataSet - extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionCGDataSetAdapter(json, params, minibatchSize)); - } -} - - -/** Function used to score a DataSet using a ComputationGraph */ -class ScoreFlatMapFunctionCGDataSetAdapter - implements FlatMapFunctionAdapter, Tuple2> { - +public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, 
Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class); private String json; private Broadcast params; private int minibatchSize; - public ScoreFlatMapFunctionCGDataSetAdapter(String json, Broadcast params, int minibatchSize) { + public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { this.json = json; this.params = params; this.minibatchSize = minibatchSize; } @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -90,6 +77,6 @@ class ScoreFlatMapFunctionCGDataSetAdapter Nd4j.getExecutioner().commit(); - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java index 91072942d..bf9e3f596 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java @@ -16,9 +16,8 @@ package org.deeplearning4j.spark.impl.graph.scoring; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -29,7 +28,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; -import lombok.val; import java.util.ArrayList; import java.util.Collections; @@ -37,18 +35,7 @@ import java.util.Iterator; import java.util.List; /** Function used to score a MultiDataSet using a given ComputationGraph */ -public class ScoreFlatMapFunctionCGMultiDataSet - extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionCGMultiDataSetAdapter(json, params, minibatchSize)); - } -} - - -/** Function used to score a MultiDataSet using a given ComputationGraph */ -class ScoreFlatMapFunctionCGMultiDataSetAdapter - implements FlatMapFunctionAdapter, Tuple2> { +public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class); private String json; @@ -56,16 +43,16 @@ class ScoreFlatMapFunctionCGMultiDataSetAdapter private int minibatchSize; - public ScoreFlatMapFunctionCGMultiDataSetAdapter(String json, Broadcast params, int minibatchSize) { + public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { this.json = json; this.params = params; 
this.minibatchSize = minibatchSize; } @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -91,6 +78,6 @@ class ScoreFlatMapFunctionCGMultiDataSetAdapter Nd4j.getExecutioner().commit(); - return out; + return out.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java index f95cc46f2..0a33fb995 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java @@ -17,9 +17,8 @@ package org.deeplearning4j.spark.impl.multilayer.evaluation; import lombok.extern.slf4j.Slf4j; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.spark.impl.evaluation.EvaluationRunner; import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,25 +35,7 @@ import java.util.concurrent.Future; * * @author Alex Black */ -public class IEvaluateFlatMapFunction - extends BaseFlatMapFunctionAdaptee, T[]> { - - public IEvaluateFlatMapFunction(boolean isCompGraph, Broadcast json, Broadcast params, - int evalNumWorkers, int evalBatchSize, T... evaluations) { - super(new IEvaluateFlatMapFunctionAdapter<>(isCompGraph, json, params, evalNumWorkers, evalBatchSize, evaluations)); - } -} - - -/** - * Function to evaluate data (using an IEvaluation instance), in a distributed manner - * Flat map function used to batch examples for computational efficiency + reduce number of IEvaluation objects returned - * for network efficiency. - * - * @author Alex Black - */ -@Slf4j -class IEvaluateFlatMapFunctionAdapter implements FlatMapFunctionAdapter, T[]> { +public class IEvaluateFlatMapFunction implements FlatMapFunction, T[]> { protected boolean isCompGraph; protected Broadcast json; @@ -70,7 +51,7 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF * this. 
Used to avoid doing too many at once (and hence memory issues) * @param evaluations Initial evaulation instance (i.e., empty Evaluation or RegressionEvaluation instance) */ - public IEvaluateFlatMapFunctionAdapter(boolean isCompGraph, Broadcast json, Broadcast params, + public IEvaluateFlatMapFunction(boolean isCompGraph, Broadcast json, Broadcast params, int evalNumWorkers, int evalBatchSize, T[] evaluations) { this.isCompGraph = isCompGraph; this.json = json; @@ -81,9 +62,9 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF } @Override - public Iterable call(Iterator dataSetIterator) throws Exception { + public Iterator call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } Future f = EvaluationRunner.getInstance().execute( @@ -91,9 +72,9 @@ class IEvaluateFlatMapFunctionAdapter implements FlatMapF IEvaluation[] result = f.get(); if(result == null){ - return Collections.emptyList(); + return Collections.emptyIterator(); } else { - return Collections.singletonList((T[])result); + return Collections.singletonList((T[])result).iterator(); } } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 2bdf926b6..03e4e55cf 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -16,17 +16,15 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSetUtil; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.linalg.util.DataSetUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; @@ -44,23 +42,7 @@ import java.util.List; * @author Alex Black */ public class FeedForwardWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>>, K, INDArray> { - - public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { - super(new FeedForwardWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to feed-forward examples, and get the network output (for example, class probabilities). - * A key value is used to keey track of which output corresponds to which input. - * - * @param Type of key, associated with each example. 
Used to keep track of which output belongs to which input example - * @author Alex Black - */ -class FeedForwardWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>>, Tuple2> { + implements PairFlatMapFunction>>, K, INDArray> { protected static Logger log = LoggerFactory.getLogger(FeedForwardWithKeyFunction.class); @@ -73,7 +55,7 @@ class FeedForwardWithKeyFunctionAdapter * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) */ - public FeedForwardWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, int batchSize) { + public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.batchSize = batchSize; @@ -81,9 +63,9 @@ class FeedForwardWithKeyFunctionAdapter @Override - public Iterable> call(Iterator>> iterator) throws Exception { + public Iterator> call(Iterator>> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -129,7 +111,7 @@ class FeedForwardWithKeyFunctionAdapter } if (tupleCount == 0) { - return Collections.emptyList(); + return Collections.emptyIterator(); } List> output = new ArrayList<>(tupleCount); @@ -185,7 +167,7 @@ class FeedForwardWithKeyFunctionAdapter Nd4j.getExecutioner().commit(); - return output; + return output.iterator(); } private INDArray getSubset(int exampleStart, int exampleEnd, INDArray from) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java index 0b3383381..4142750d0 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java @@ -16,11 +16,11 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import org.apache.spark.api.java.function.DoubleFlatMapFunction; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.util.BaseDoubleFlatMapFunctionAdaptee; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -39,23 +39,7 @@ import java.util.List; * @author Alex Black * @see ScoreExamplesWithKeyFunction */ -public class ScoreExamplesFunction extends BaseDoubleFlatMapFunctionAdaptee> { - - public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/**Function to score examples individually. Note that scoring is batched for computational efficiency.
- * This is essentially a Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
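The unkeyed ScoreExamplesFunction (both the graph and multilayer versions) is now a DoubleFlatMapFunction over the partition iterator, so it can be applied with flatMapToDouble to obtain a JavaDoubleRDD of per-example scores. A minimal sketch of that contract, with PerExampleScoreSketch and its placeholder arithmetic standing in for the real batched scoring:

    import org.apache.spark.api.java.function.DoubleFlatMapFunction;

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    public class PerExampleScoreSketch implements DoubleFlatMapFunction<Iterator<double[]>> {
        @Override
        public Iterator<Double> call(Iterator<double[]> iterator) {
            List<Double> scores = new ArrayList<>();
            while (iterator.hasNext()) {
                double s = 0.0;
                for (double d : iterator.next()) {
                    s += d * d;                 // placeholder; the real code scores a DataSet batch
                }
                scores.add(s);
            }
            return scores.iterator();           // one score per example, order preserved
        }
    }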
- * Note: This method returns a score for each example, but the association between examples and scores is lost. In - * cases where we need to know the score for particular examples, use {@link ScoreExamplesWithKeyFunction} - * @author Alex Black - * @see ScoreExamplesWithKeyFunction - */ -class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter, Double> { +public class ScoreExamplesFunction implements DoubleFlatMapFunction> { protected static Logger log = LoggerFactory.getLogger(ScoreExamplesFunction.class); @@ -64,7 +48,7 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter params, Broadcast jsonConfig, + public ScoreExamplesFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; @@ -74,9 +58,9 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter call(Iterator iterator) throws Exception { + public Iterator call(Iterator iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -119,6 +103,6 @@ class ScoreExamplesFunctionAdapter implements FlatMapFunctionAdapter - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { - super(new ScoreExamplesWithKeyFunctionAdapter(params, jsonConfig, addRegularizationTerms, batchSize)); - } -} - - -/** - * Function to score examples individually, where each example is associated with a particular key
- * Note that scoring is batched for computational efficiency.
- * This is the Spark implementation of the {@link MultiLayerNetwork#scoreExamples(DataSet, boolean)} method
- * Note: The DataSet objects passed in must have exactly one example in them (otherwise: can't have a 1:1 association - * between keys and data sets to score) - * - * @param Type of key, associated with each example. Used to keep track of which score belongs to which example - * @author Alex Black - * @see ScoreExamplesFunction - */ -class ScoreExamplesWithKeyFunctionAdapter - implements FlatMapFunctionAdapter>, Tuple2> { - - protected static Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class); +@Slf4j +public class ScoreExamplesWithKeyFunction implements PairFlatMapFunction>, K, Double> { private final Broadcast params; private final Broadcast jsonConfig; @@ -81,8 +56,7 @@ class ScoreExamplesWithKeyFunctionAdapter * @param addRegularizationTerms if true: add regularization terms (L1, L2) to the score * @param batchSize Batch size to use when scoring */ - public ScoreExamplesWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, - boolean addRegularizationTerms, int batchSize) { + public ScoreExamplesWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean addRegularizationTerms, int batchSize) { this.params = params; this.jsonConfig = jsonConfig; this.addRegularization = addRegularizationTerms; @@ -91,9 +65,9 @@ class ScoreExamplesWithKeyFunctionAdapter @Override - public Iterable> call(Iterator> iterator) throws Exception { + public Iterator> call(Iterator> iterator) throws Exception { if (!iterator.hasNext()) { - return Collections.emptyList(); + return Collections.emptyIterator(); } MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); @@ -143,6 +117,6 @@ class ScoreExamplesWithKeyFunctionAdapter log.debug("Scored {} examples ", totalCount); } - return ret; + return ret.iterator(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java index 480f1dcd2..8063ba8e3 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java @@ -16,9 +16,11 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import org.datavec.spark.transform.BaseFlatMapFunctionAdaptee; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -26,43 +28,25 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import scala.Tuple2; -import lombok.val; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.List; -public class ScoreFlatMapFunction extends BaseFlatMapFunctionAdaptee, Tuple2> { - - public 
ScoreFlatMapFunction(String json, Broadcast params, int minibatchSize) { - super(new ScoreFlatMapFunctionAdapter(json, params, minibatchSize)); - } - -} - - -class ScoreFlatMapFunctionAdapter implements FlatMapFunctionAdapter, Tuple2> { - - private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunction.class); +@Slf4j +@AllArgsConstructor +public class ScoreFlatMapFunction implements FlatMapFunction, Tuple2> { private String json; private Broadcast params; private int minibatchSize; - public ScoreFlatMapFunctionAdapter(String json, Broadcast params, int minibatchSize) { - this.json = json; - this.params = params; - this.minibatchSize = minibatchSize; - } - @Override - public Iterable> call(Iterator dataSetIterator) throws Exception { + public Iterator> call(Iterator dataSetIterator) throws Exception { if (!dataSetIterator.hasNext()) { - return Collections.singletonList(new Tuple2<>(0, 0.0)); + return Collections.singletonList(new Tuple2<>(0, 0.0)).iterator(); } DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate @@ -87,6 +71,6 @@ class ScoreFlatMapFunctionAdapter implements FlatMapFunctionAdapter - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, - int batchSize) { - super(new VaeReconstructionErrorWithKeyFunctionAdapter(params, jsonConfig, batchSize)); - } -} - - -/** - * Function to calculate the reconstruction error for a variational autoencoder, that is the first layer in a - * MultiLayerNetwork.
- * Note that the VAE must be using a loss function, not a {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution}
- * Also note that scoring is batched for computational efficiency.
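The scoring classes in this part of the change all follow the same migration: the Spark 1.x *FunctionAdapter/*Adaptee wrappers are deleted and each class now implements the Spark 2 functional interface directly, whose call(...) returns a java.util.Iterator rather than an Iterable. A minimal sketch of the resulting pattern, using illustrative class and field names rather than the actual DL4J scoring classes:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;

// Spark 2.x style: call(...) returns an Iterator, so an empty partition maps to
// Collections.emptyIterator() and accumulated results are returned via list.iterator()
public class ExampleKeyedScoreFunction<K, V> implements PairFlatMapFunction<Iterator<Tuple2<K, V>>, K, Double> {

    @Override
    public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, V>> iterator) {
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        List<Tuple2<K, Double>> out = new ArrayList<>();
        while (iterator.hasNext()) {
            Tuple2<K, V> pair = iterator.next();
            out.add(new Tuple2<>(pair._1(), 0.0)); // placeholder score per keyed example
        }
        return out.iterator();
    }
}
```

The same Iterable-to-Iterator adjustment is what drives the Collections.emptyIterator() and ret.iterator() edits in the hunks above.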
- * - * @author Alex Black - * @see VaeReconstructionProbWithKeyFunction - */ -class VaeReconstructionErrorWithKeyFunctionAdapter extends BaseVaeScoreWithKeyFunctionAdapter { +public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { /** * @param params MultiLayerNetwork parameters * @param jsonConfig MultiLayerConfiguration, as json * @param batchSize Batch size to use when scoring */ - public VaeReconstructionErrorWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { super(params, jsonConfig, batchSize); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index 7bba68b28..e8fc8416f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -21,12 +21,8 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunctionAdapter; -import org.deeplearning4j.spark.util.BasePairFlatMapFunctionAdaptee; +import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import scala.Tuple2; - -import java.util.Iterator; /** @@ -36,25 +32,7 @@ import java.util.Iterator; * * @author Alex Black */ -public class VaeReconstructionProbWithKeyFunction - extends BasePairFlatMapFunctionAdaptee>, K, Double> { - - public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, - boolean useLogProbability, int batchSize, int numSamples) { - super(new VaeReconstructionProbWithKeyFunctionAdapter(params, jsonConfig, useLogProbability, batchSize, - numSamples)); - } -} - - -/** - * Function to calculate the reconstruction probability for a variational autoencoder, that is the first layer in a - * MultiLayerNetwork.
- * Note that scoring is batched for computational efficiency.
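Internally, these keyed VAE functions rebuild the network from the broadcast JSON configuration and parameters on each executor and score batches through the variational autoencoder that sits as layer 0 of the MultiLayerNetwork. A hedged sketch of that per-batch step, with illustrative method and variable names; reconstructionLogProbability(INDArray, int) is the layer API referenced in the javadoc above:

```java
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class VaeScoringSketch {

    // Score one batch of features through the VAE at layer 0 of the network,
    // mirroring what the keyed reconstruction-probability function does per partition
    public static INDArray scoreBatch(MultiLayerNetwork network, INDArray features, int numSamples) {
        VariationalAutoencoder vae = (VariationalAutoencoder) network.getLayer(0);
        return vae.reconstructionLogProbability(features, numSamples);
    }
}
```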
- * - * @author Alex Black - */ -class VaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeReconstructionProbWithKeyFunctionAdapter { +public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructionProbWithKeyFunction { /** @@ -64,7 +42,7 @@ class VaeReconstructionProbWithKeyFunctionAdapter extends BaseVaeReconstructi * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} */ - public VaeReconstructionProbWithKeyFunctionAdapter(Broadcast params, Broadcast jsonConfig, + public VaeReconstructionProbWithKeyFunction(Broadcast params, Broadcast jsonConfig, boolean useLogProbability, int batchSize, int numSamples) { super(params, jsonConfig, useLogProbability, batchSize, numSamples); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 30e6b395f..b6d5654ec 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -67,7 +67,7 @@ import java.io.IOException; import java.io.OutputStream; import java.util.*; -import static com.google.common.base.Preconditions.checkArgument; +import static org.nd4j.shade.guava.base.Preconditions.checkArgument; /** * ParameterAveragingTrainingMaster: A {@link TrainingMaster} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java deleted file mode 100644 index 5f9d25547..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.DoubleFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -/** - * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DoubleFlatMapFunction - * - */ -public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseDoubleFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterable call(T t) throws Exception { - return adapter.call(t); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java deleted file mode 100644 index b58cab202..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-1/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -/** - * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to PairFlatMapFunction - * - */ -public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { - - protected final FlatMapFunctionAdapter> adapter; - - public BasePairFlatMapFunctionAdaptee(FlatMapFunctionAdapter> adapter) { - this.adapter = adapter; - } - - @Override - public Iterable> call(T t) throws Exception { - return adapter.call(t); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java deleted file mode 100644 index 49a05231b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BaseDoubleFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.DoubleFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; - -import java.util.Iterator; - -/** - * DoubleFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to DoubleFlatMapFunction - * - */ -public class BaseDoubleFlatMapFunctionAdaptee implements DoubleFlatMapFunction { - - protected final FlatMapFunctionAdapter adapter; - - public BaseDoubleFlatMapFunctionAdaptee(FlatMapFunctionAdapter adapter) { - this.adapter = adapter; - } - - @Override - public Iterator call(T t) throws Exception { - return adapter.call(t).iterator(); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java deleted file mode 100644 index b28a8dcc2..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/spark-2/java/org/deeplearning4j/spark/util/BasePairFlatMapFunctionAdaptee.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.spark.util; - -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.datavec.spark.functions.FlatMapFunctionAdapter; -import scala.Tuple2; - -import java.util.Iterator; - -/** - * PairFlatMapFunction adapter to hide incompatibilities between Spark 1.x and Spark 2.x - * - * This class should be used instead of direct referral to PairFlatMapFunction - * - */ -public class BasePairFlatMapFunctionAdaptee implements PairFlatMapFunction { - - protected final FlatMapFunctionAdapter> adapter; - - public BasePairFlatMapFunctionAdaptee(FlatMapFunctionAdapter> adapter) { - this.adapter = adapter; - } - - @Override - public Iterator> call(T t) throws Exception { - return adapter.call(t).iterator(); - } -} diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml index 1919fcb06..f753fefae 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml @@ -36,16 +36,9 @@ UTF-8 UTF-8 - - 2.1.0 - 2 1.0.0_spark_2-SNAPSHOT - 2.1.0 2.11.12 @@ -194,21 +187,6 @@ jaxb-impl ${jaxb.version}
- - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-remote_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - io.netty netty @@ -280,36 +258,6 @@
- - spark-2 - - - spark.major.version - 2 - - - - - com.typesafe.akka - akka-remote_2.11 - 2.3.11 - - - - - cdh5 - - - org.apache.hadoop - https://repository.cloudera.com/artifactory/cloudera-repos/ - - - - 2.0.0-cdh4.6.0 - 1.2.0-cdh5.3.0 - - - test-nd4j-native diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml index c1c9a4664..1b4f33c1e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/pom.xml @@ -100,11 +100,27 @@ test-nd4j-native - - - test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-native + ${project.version} + test + + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + @@ -114,26 +130,11 @@ ${project.version}
- - com.google.guava - guava - ${guava.version} - com.google.protobuf protobuf-java ${google.protobuf.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - javax.ws.rs javax.ws.rs-api @@ -210,76 +211,6 @@ leveldbjni-all ${leveldb.version} - - com.typesafe.akka - akka-contrib_2.11 - ${akka.version} - - - - - - - com.fasterxml.jackson.core - jackson-core - ${spark.jackson.version} - - - - com.fasterxml.jackson.core - jackson-databind - ${spark.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark.jackson.version} - - - - - com.fasterxml.jackson.module - jackson-module-scala_2.11 - ${spark.jackson.version} - - - com.google.code.findbugs - jsr305 - - - - - - - com.fasterxml.jackson.datatype - jackson-datatype-jdk8 - ${spark.jackson.version} - - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${spark.jackson.version} - - - - - - - com.typesafe - config - ${typesafe.config.version} - - - - com.typesafe.akka - akka-cluster_2.11 - ${akka.version} - com.beust @@ -306,14 +237,6 @@ test - - org.nd4j - nd4j-native - ${project.version} - test - - - org.webjars.npm @@ -722,6 +645,4 @@
- - diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java index a8a499ec6..f694291c8 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/FunctionType.java @@ -17,10 +17,17 @@ package org.deeplearning4j.ui.api; /** - * Enumeration for the type of function. Mainly used in specifying {@link Route} instances + * Enumeration for the type of function. Mainly used in specifying {@link Route} instances
+ * Supplier: No args
+ * Function: 1 arg
+ * BiFunction: 2 args
+ * Function3: 3 args
+ * Request0Function: Supplier + request, no args (as Function)
+ * Request1Function: Supplier + request + 1 arg (as BiFunction)
* * @author Alex Black */ public enum FunctionType { - Supplier, Function, BiFunction, Function3 + Supplier, Function, BiFunction, Function3, + Request0Function, Request1Function } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java index 0f0a731d8..f2dd1e017 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/api/Route.java @@ -18,6 +18,7 @@ package org.deeplearning4j.ui.api; import lombok.AllArgsConstructor; import lombok.Data; +import play.mvc.Http; import play.mvc.Result; import java.util.function.BiFunction; @@ -38,17 +39,27 @@ public class Route { private final Supplier supplier; private final Function function; private final BiFunction function2; + private final Function request0Function; + private final BiFunction request1Function; public Route(String route, HttpMethod method, FunctionType functionType, Supplier supplier) { - this(route, method, functionType, supplier, null, null); + this(route, method, functionType, supplier, null, null, null, null); } public Route(String route, HttpMethod method, FunctionType functionType, Function function) { - this(route, method, functionType, null, function, null); + this(route, method, functionType, null, function, null, null, null); + } + + public static Route request0Function(String route, HttpMethod httpMethod, Function function){ + return new Route(route, httpMethod, FunctionType.Request0Function, null, null, null, function, null); + } + + public static Route request1Function(String route, HttpMethod httpMethod, BiFunction function){ + return new Route(route, httpMethod, FunctionType.Request1Function, null, null, null, null, function); } public Route(String route, HttpMethod method, FunctionType functionType, BiFunction function) { - this(route, method, functionType, null, null, function); + this(route, method, functionType, null, null, function, null, null); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java index fe92756cc..0886af8a9 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java @@ -24,6 +24,7 @@ import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.i18n.I18NResource; +import play.libs.Files; import play.libs.Json; import play.mvc.Http; import play.mvc.Result; @@ -31,9 +32,9 @@ import play.mvc.Results; import java.io.File; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.*; -import static play.mvc.Controller.request; import static play.mvc.Results.badRequest; import static play.mvc.Results.ok; @@ -63,8 +64,8 @@ public class TsneModule implements UIModule { () -> ok(org.deeplearning4j.ui.views.html.tsne.Tsne.apply())); Route r2 = new Route("/tsne/sessions", HttpMethod.GET, FunctionType.Supplier, this::listSessions); 
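The two new route types pass the Http.Request into the handler explicitly instead of reading it from the removed Http.Context thread-local. A short usage sketch of the request0Function/request1Function factories added above, assuming an illustrative module class and handler names (the generic parameters of the underlying Function/BiFunction fields are assumed to be Http.Request and Result, as in the Route changes above):

```java
import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.Route;
import play.mvc.Http;
import play.mvc.Result;

import static play.mvc.Results.ok;

public class ExampleUIModule {

    // Request0Function: the handler receives the Http.Request and no path parameters
    private Result upload(Http.Request request) {
        return ok("received " + request.method());
    }

    // Request1Function: the handler receives the Http.Request plus one path parameter
    private Result post(Http.Request request, String sessionId) {
        return ok("posted to session " + sessionId);
    }

    public List<Route> routes() {
        Route r1 = Route.request0Function("/example/upload", HttpMethod.POST, this::upload);
        Route r2 = Route.request1Function("/example/post/:sid", HttpMethod.POST, this::post);
        return Arrays.asList(r1, r2);
    }
}
```

TsneModule's uploadFile(Http.Request) and postFile(Http.Request, String) handlers below are registered in exactly this way.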
Route r3 = new Route("/tsne/coords/:sid", HttpMethod.GET, FunctionType.Function, this::getCoords); - Route r4 = new Route("/tsne/upload", HttpMethod.POST, FunctionType.Supplier, this::uploadFile); - Route r5 = new Route("/tsne/post/:sid", HttpMethod.POST, FunctionType.Function, this::postFile); + Route r4 = Route.request0Function("/tsne/upload", HttpMethod.POST, this::uploadFile); + Route r5 = Route.request1Function("/tsne/post/:sid", HttpMethod.POST, this::postFile); return Arrays.asList(r1, r2, r3, r4, r5); } @@ -106,22 +107,22 @@ public class TsneModule implements UIModule { } } - private Result uploadFile() { - Http.MultipartFormData body = request().body().asMultipartFormData(); - List fileParts = body.getFiles(); + private Result uploadFile(Http.Request request) { + Http.MultipartFormData body = request.body().asMultipartFormData(); + List> fileParts = body.getFiles(); if (fileParts.isEmpty()) { return badRequest("No file uploaded"); } - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); + Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); String fileName = uploadedFile.getFilename(); String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getFile(); + File file = uploadedFile.getRef().path().toFile(); try { - uploadedFileLines = FileUtils.readLines(file); + uploadedFileLines = FileUtils.readLines(file, StandardCharsets.UTF_8); } catch (IOException e) { return badRequest("Could not read from uploaded file"); } @@ -129,21 +130,21 @@ public class TsneModule implements UIModule { return ok("File uploaded: " + fileName + ", " + contentType + ", " + file); } - private Result postFile(String sid) { + private Result postFile(Http.Request request, String sid) { // System.out.println("POST FILE CALLED: " + sid); - Http.MultipartFormData body = request().body().asMultipartFormData(); - List fileParts = body.getFiles(); + Http.MultipartFormData body = request.body().asMultipartFormData(); + List> fileParts = body.getFiles(); if (fileParts.isEmpty()) { // System.out.println("**** NO FILE ****"); return badRequest("No file uploaded"); } - Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); + Http.MultipartFormData.FilePart uploadedFile = fileParts.get(0); String fileName = uploadedFile.getFilename(); String contentType = uploadedFile.getContentType(); - File file = uploadedFile.getFile(); + File file = uploadedFile.getRef().path().toFile(); List lines; try { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java index 747284e7e..c0ebe4ed8 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java @@ -37,18 +37,16 @@ import org.deeplearning4j.ui.module.defaultModule.DefaultModule; import org.deeplearning4j.ui.module.remote.RemoteReceiverModule; import org.deeplearning4j.ui.module.train.TrainModule; import org.deeplearning4j.ui.module.tsne.TsneModule; -import org.deeplearning4j.ui.play.misc.FunctionUtil; import org.deeplearning4j.ui.play.staticroutes.Assets; -import org.deeplearning4j.ui.play.staticroutes.I18NRoute; -import org.deeplearning4j.ui.play.staticroutes.MultiSessionI18NRoute; import org.deeplearning4j.ui.storage.FileStatsStorage; import 
org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.impl.QueueStatsStorageListener; import org.deeplearning4j.util.DL4JFileUtils; import org.nd4j.linalg.function.Function; import org.nd4j.linalg.primitives.Pair; +import play.BuiltInComponents; import play.Mode; -import play.api.routing.Router; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; @@ -60,6 +58,8 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; +import static play.mvc.Results.ok; + /** * A UI server based on the Play framework @@ -166,63 +166,6 @@ public class PlayUIServer extends UIServer { System.exit(1); } - RoutingDsl routingDsl = new RoutingDsl(); - - //Set up index page and assets routing - //The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent - // definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions - //This way, routing is not directly dependent ot Play API. Furthermore, Play 2.5 switches to using these Java interfaces - // anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls... - if (multiSession) { - routingDsl.GET("/setlang/:sessionId/:to").routeTo(FunctionUtil.biFunction(new MultiSessionI18NRoute())); - } else { - routingDsl.GET("/setlang/:to").routeTo(FunctionUtil.function(new I18NRoute())); - } - routingDsl.GET("/assets/*file").routeTo(FunctionUtil.function(new Assets(ASSETS_ROOT_DIRECTORY))); - - uiModules.add(new DefaultModule(multiSession)); //For: navigation page "/" - uiModules.add(new TrainModule(multiSession, statsStorageLoader, this::getAddress)); - uiModules.add(new ConvolutionalListenerModule()); - uiModules.add(new TsneModule()); - uiModules.add(new SameDiffModule()); - remoteReceiverModule = new RemoteReceiverModule(); - uiModules.add(remoteReceiverModule); - - //Check service loader mechanism (Arbiter UI, etc) for modules - modulesViaServiceLoader(uiModules); - - for (UIModule m : uiModules) { - List routes = m.getRoutes(); - for (Route r : routes) { - RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute()); - switch (r.getFunctionType()) { - case Supplier: - ppm.routeTo(FunctionUtil.function0(r.getSupplier())); - break; - case Function: - ppm.routeTo(FunctionUtil.function(r.getFunction())); - break; - case BiFunction: - ppm.routeTo(FunctionUtil.biFunction(r.getFunction2())); - break; - case Function3: - default: - throw new RuntimeException("Not yet implemented"); - } - } - - //Determine which type IDs this module wants to receive: - List typeIDs = m.getCallbackTypeIDs(); - for (String typeID : typeIDs) { - List list = typeIDModuleMap.get(typeID); - if (list == null) { - list = Collections.synchronizedList(new ArrayList<>()); - typeIDModuleMap.put(typeID, list); - } - list.add(m); - } - } - String portProperty = System.getProperty(DL4JSystemProperties.UI_SERVER_PORT_PROPERTY); if (portProperty != null) { try { @@ -233,6 +176,7 @@ public class PlayUIServer extends UIServer { } } + //Set play secret key, if required //http://www.playframework.com/documentation/latest/ApplicationSecret String crypto = System.getProperty("play.crypto.secret"); @@ -245,9 +189,9 @@ public class PlayUIServer extends UIServer { System.setProperty("play.crypto.secret", base64); } - Router router = routingDsl.build(); + try { - server = Server.forRouter(router, Mode.PROD, port); + server = 
Server.forRouter(Mode.PROD, port, this::createRouter); } catch (Throwable e){ if(e.getMessage().contains("'play.crypto.provider")){ //Usual cause: user's uber-jar does not include application.conf @@ -284,6 +228,79 @@ public class PlayUIServer extends UIServer { setStopped(false); } + protected Router createRouter(BuiltInComponents builtInComponents){ + RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); + + //Set up index page and assets routing + //The definitions and FunctionUtil may look a bit weird here... this is used to translate implementation independent + // definitions (i.e., Java Supplier, Function etc interfaces) to the Play-specific versions + //This way, routing is not directly dependent ot Play API. Furthermore, Play 2.5 switches to using these Java interfaces + // anyway; thus switching 2.5 should be as simple as removing the FunctionUtil calls... + if (multiSession) { + routingDsl.GET("/setlang/:sessionId/:to").routingTo((request, sid, to) -> { + I18NProvider.getInstance(sid.toString()).setDefaultLanguage(to.toString()); + return ok(); + }); + } else { + routingDsl.GET("/setlang/:to").routingTo((request, to) -> { + I18NProvider.getInstance().setDefaultLanguage(to.toString()); + return ok(); + }); + } + routingDsl.GET("/assets/*file").routingTo((request, file) -> Assets.assetRequest(ASSETS_ROOT_DIRECTORY, file.toString())); + + uiModules.add(new DefaultModule(multiSession)); //For: navigation page "/" + uiModules.add(new TrainModule(multiSession, statsStorageLoader, this::getAddress)); + uiModules.add(new ConvolutionalListenerModule()); + uiModules.add(new TsneModule()); + uiModules.add(new SameDiffModule()); + remoteReceiverModule = new RemoteReceiverModule(); + uiModules.add(remoteReceiverModule); + + //Check service loader mechanism (Arbiter UI, etc) for modules + modulesViaServiceLoader(uiModules); + + for (UIModule m : uiModules) { + List routes = m.getRoutes(); + for (Route r : routes) { + RoutingDsl.PathPatternMatcher ppm = routingDsl.match(r.getHttpMethod().name(), r.getRoute()); + switch (r.getFunctionType()) { + case Supplier: + ppm.routingTo(request -> r.getSupplier().get()); + break; + case Function: + ppm.routingTo((request, arg) -> r.getFunction().apply(arg.toString())); + break; + case BiFunction: + ppm.routingTo((request, arg0, arg1) -> r.getFunction2().apply(arg0.toString(), arg1.toString())); + break; + case Request0Function: + ppm.routingTo(request -> r.getRequest0Function().apply(request)); + break; + case Request1Function: + ppm.routingTo((request, arg0) -> r.getRequest1Function().apply(request, arg0.toString())); + break; + case Function3: + default: + throw new RuntimeException("Not yet implemented"); + } + } + + //Determine which type IDs this module wants to receive: + List typeIDs = m.getCallbackTypeIDs(); + for (String typeID : typeIDs) { + List list = typeIDModuleMap.get(typeID); + if (list == null) { + list = Collections.synchronizedList(new ArrayList<>()); + typeIDModuleMap.put(typeID, list); + } + list.add(m); + } + } + Router router = routingDsl.build(); + return router; + } + @Override public String getAddress() { String addr = server.mainAddress().toString(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java index dd100b3d1..76d6e7361 100644 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/staticroutes/Assets.java @@ -16,18 +16,17 @@ package org.deeplearning4j.ui.play.staticroutes; -import com.google.common.net.HttpHeaders; +import org.nd4j.shade.guava.net.HttpHeaders; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FilenameUtils; import org.nd4j.linalg.io.ClassPathResource; -import play.api.libs.MimeTypes; import play.mvc.Result; +import play.mvc.StaticFileMimeTypes; import java.io.InputStream; -import java.util.function.Function; +import java.util.Optional; -import static play.mvc.Http.Context.Implicit.response; import static play.mvc.Results.ok; /** @@ -37,11 +36,9 @@ import static play.mvc.Results.ok; */ @AllArgsConstructor @Slf4j -public class Assets implements Function { - private final String assetsRootDirectory; +public class Assets { - @Override - public Result apply(String s) { + public static Result assetRequest(String assetsRootDirectory, String s) { String fullPath; if(s.startsWith("webjars/")){ @@ -60,15 +57,12 @@ public class Assets implements Function { String fileName = FilenameUtils.getName(fullPath); - response().setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\""); - scala.Option contentType = MimeTypes.forFileName(fileName); + Optional contentType = StaticFileMimeTypes.fileMimeTypes().forFileName(fileName); String ct; - if (contentType.isDefined()) { - ct = contentType.get(); - } else { - ct = "application/octet-stream"; - } + ct = contentType.orElse("application/octet-stream"); - return ok(inputStream).as(ct); + return ok(inputStream) + .withHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"" + fileName + "\"") + .as(ct); } } diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index be75282e5..e2890fc61 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -16,7 +16,7 @@ package org.deeplearning4j.integration; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 9440a59ae..b1de34a2c 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -17,8 +17,8 @@ package org.deeplearning4j.integration; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import 
org.deeplearning4j.integration.util.CountingMultiDataSetIterator; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java index 463734b82..a70d8dd2f 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/RNNTestCases.java @@ -16,7 +16,7 @@ package org.deeplearning4j.integration.testcases; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.integration.testcases.misc.CharacterIterator; import org.deeplearning4j.integration.testcases.misc.CompositeMultiDataSetPreProcessor; diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index 954d8341d..a139c4f44 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -174,11 +174,6 @@ slf4j-api ${slf4j.version}
- - com.google.guava - guava - ${guava.version} - junit junit diff --git a/docs/deeplearning4j/templates/benchmark.md b/docs/deeplearning4j/templates/benchmark.md index 330ff99a6..93e30fda9 100644 --- a/docs/deeplearning4j/templates/benchmark.md +++ b/docs/deeplearning4j/templates/benchmark.md @@ -45,7 +45,7 @@ Ideally, these should be excluded from any timing/performance results you report For example: what BLAS implementation (MKL, OpenBLAS, etc)? If you are using CUDA, are you using CuDNN? ND4J and DL4J can use these libraries (MKL, CuDNN) when they are available - but are not always available by default. If they are not made available, performance can be lower - sometimes considerably. -This is especially important when comparing results between libraries: for example, if you compared two libraries (one using OpenBLAS, another using MLK) your results may simply reflect the performance differences it the BLAS library being used - and not the performance oth the libraries being tested. Similarly, one library with CuDNN and another without CuDNN may simply reflect the performance benefit of using CuDNN. +This is especially important when comparing results between libraries: for example, if you compared two libraries (one using OpenBLAS, another using MKL) your results may simply reflect the performance differences it the BLAS library being used - and not the performance oth the libraries being tested. Similarly, one library with CuDNN and another without CuDNN may simply reflect the performance benefit of using CuDNN. 3. How are things configured? diff --git a/gym-java-client/pom.xml b/gym-java-client/pom.xml index 1b7488c81..fc82d01b2 100644 --- a/gym-java-client/pom.xml +++ b/gym-java-client/pom.xml @@ -320,4 +320,18 @@ + + + + nd4j-backend + + + libnd4j.cuda + + + + nd4j-cuda-${libnd4j.cuda} + + + diff --git a/jumpy/pom.xml b/jumpy/pom.xml index e298d1ce9..c3132e846 100644 --- a/jumpy/pom.xml +++ b/jumpy/pom.xml @@ -44,12 +44,13 @@ false 0.2.4 + nd4j-native org.nd4j - nd4j-native + ${nd4j.backend} ${dl4j.version} @@ -179,4 +180,18 @@ + + + + nd4j-backend + + + libnd4j.cuda + + + + nd4j-cuda-${libnd4j.cuda} + + + diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 8e940bedb..add8960a3 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -341,7 +341,7 @@ elseif(CPU_BLAS) endif() endif() - if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE)) + if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(APPLE) AND NOT(WIN32)) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") endif() diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 3237e5033..3ef3716b3 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -659,6 +660,8 @@ namespace nd4j { void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + /** * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) * tad - array to broadcast @@ -672,6 +675,9 @@ namespace nd4j { void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + 
void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + + /** * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting * other - input array @@ -692,6 +698,9 @@ namespace nd4j { void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + + /** * apply a scalar operation to an array * scalar - input scalar @@ -704,6 +713,9 @@ namespace nd4j { template void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + template + void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + /** * apply a scalar operation to an array * scalar - input array which is simple scalar @@ -714,6 +726,7 @@ namespace nd4j { void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; #if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) template diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index fdbcae49f..82427f9b9 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2663,6 +2663,88 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* delete pTarget; } + + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) + throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); + if(target == nullptr || other == nullptr) + throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !"); + + if (isEmpty() || other->isEmpty()) + return; + + NDArray::prepareSpecialUse({target}, {this, other}); + + if (isScalar()) { + NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + temp.assign(this); + temp.applyPairwiseTransform(op.p, other, target, extraArgs); + return; + } + if (other->isScalar()) { + this->applyScalarArr(op.s, other, target, extraArgs); + return; + } + + const NDArray* min(other); + const NDArray* max(this); + + if(this->rankOf() < other->rankOf()) { + max = other; + min = this; + } + + if(checkTargetShape) { + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType()) + throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); + if(dataType() != other->dataType()) + throw 
std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); + } + + NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext()); + // check whether max array has to be tiled + if(!max->isSameShape(target)) { + // evaluate repeating dimensions for tile operation + std::vector repeatMax(max->rankOf()); + for(int i = 1; i <= max->rankOf(); ++i) + repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]); + max->tile(repeatMax, *pTarget); + } + else + pTarget->assign(max); + + // check whether min array has to be tiled + std::vector repeatMin(min->rankOf()); + int product = 1; + for(int i = min->rankOf(); i >=1 ; --i) { + repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]); + product *= repeatMin[i-1]; + } + + auto pMin = const_cast(min); + if(product != 1 ) + pMin = new NDArray(min->tile(repeatMin)); + + std::vector sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin); + + if(max == this) + pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs); + else + pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs); + + if(pMin != min) + delete pMin; + if(pTarget != target) + delete pTarget; + } + + + ////////////////////////////////////////////////////////////////////////// NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { if (isEmpty() || other.isEmpty()) { @@ -2801,6 +2883,67 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector registerSpecialUse({result}, {this, other}); } + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { + if (!isZ()) + throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); + if(isEmpty() || other->isEmpty()) { + if(!target->isEmpty()) + throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); + return; + } + + if (dimensions.empty()) + return; + + auto result = target == nullptr ? 
this : target; + + if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { + NDArray::prepareSpecialUse({result}, {this, other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {this, other}); + return; + } + + NDArray *min(nullptr), *max(nullptr); + if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + max = this; + min = const_cast(other); + } + else { + max = const_cast(other); + min = this; + } + + if(result->dataType() != dataType()) + throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); + if(!result->isSameShape(max)) + throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !"); + if(_dataType != other->_dataType) + throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) + std::sort(copy.begin(), copy.end()); + + Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size()); + if (tadLength != min->lengthOf()) + throw std::runtime_error("Tad length mismatch"); + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + + // TODO: eventually we want separate tads here + NDArray::prepareSpecialUse({result}, {this, other}); + if(max == this) + NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + else + NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({result}, {this, other}); + } + ////////////////////////////////////////////////////////////////////////// void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) { std::vector vec(dimensions); @@ -3043,6 +3186,22 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray * NDArray::registerSpecialUse({target}, {this, other}); } +//////////////////////////////////////////////////////////////////////// + void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ + if (isS()) + throw 
std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); + if (other->lengthOf() != target->lengthOf()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); + if (!target->isZ()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); + if (dataType() != other->dataType()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); + + NDArray::prepareSpecialUse({target}, {this, other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); + NDArray::registerSpecialUse({target}, {this, other}); + } + ////////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { applyPairwiseTransform(op, &other, this, extraParams); @@ -3585,13 +3744,52 @@ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); + + if (target == nullptr || target->dataType() != this->dataType()) + throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has not bool type!"); + if (dataType() != scalar->dataType()) { + nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); + throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + } + + NDArray::prepareSpecialUse({target}, {this, scalar}); + NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target->dataType()): nullptr); + NDArray::registerSpecialUse({target}, {this, scalar}); + } + +//////////////////////////////////////////////////////////////////////// + template + void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const { + + NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, &scalarArr, target, extraParams); + } + + template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; + + //////////////////////////////////////////////////////////////////////// void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams) const { if (isS()) throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - if (target->dataType() != nd4j::DataType::INT64) - throw std::runtime_error("NDArray::applyIndexReduce operations return INT64"); + if (target->dataType() != nd4j::DataType::INT64 && target->dataType() != nd4j::DataType::INT32) + throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); void* params = extraParams != nullptr ? 
const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; diff --git a/libnd4j/blas/NativeOpExecutioner.h b/libnd4j/blas/NativeOpExecutioner.h index 534453a0f..cae7a4e56 100644 --- a/libnd4j/blas/NativeOpExecutioner.h +++ b/libnd4j/blas/NativeOpExecutioner.h @@ -189,6 +189,16 @@ static void execScalarBool(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism = true); +static void execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hSscalarShapeInfo, + void *dScalar, Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + static void execScalar(nd4j::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, @@ -215,6 +225,20 @@ static void execScalarBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + static void execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + /** * * @param opNum @@ -275,6 +299,30 @@ static void execScalarBool(nd4j::LaunchContext *lc, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); + + static void execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *x, Nd4jLong *xShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *y, Nd4jLong *yShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *result, Nd4jLong *resultShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); /** * @@ -308,6 +356,16 @@ static void execScalarBool(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams); + static void execPairwiseIntTransform(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams); + /** * * @param opNum diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 9ce90176f..9bca7bb10 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -79,6 +79,18 @@ bool verbose = false; extern "C" { +/** + * This function returns last error code stored, + * @return non-zero if something bad happened + */ +ND4J_EXPORT int lastErrorCode(); + +/** + * This function returns last error message, if last error code > 0 + * @return + */ +ND4J_EXPORT const char* 
lastErrorMessage(); + /** * * @param p @@ -557,38 +569,6 @@ ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); - -/** -* Append an input array -* to the end of a flat array -* in a particular order -* @param offset the offset of the array to start at -* @param order the order -* @param result the result array -* @param resultShapeInfo the shape info for te array -* @param input the input for the array -* @param inputShapeInfo the shape information for that array -*/ -ND4J_EXPORT void flatten( - Nd4jPointer *extraPointers, - int offset, - char order, - void *result, Nd4jLong *resultShapeInfo, - void *dresult, Nd4jLong *dresultShapeInfo, - void *input, Nd4jLong *inputShapeInfo, - void *dinput, Nd4jLong *dinputShapeInfo); - -ND4J_EXPORT void concat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfo, - Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo, - void *result, Nd4jLong *resultShapeInfo, - void *dresult, Nd4jLong *dresultShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers); - - ND4J_EXPORT void specialConcat ( Nd4jPointer *extraPointers, int dimension, diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index e320b4f57..22fd9eca4 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -24,6 +24,10 @@ #include #include +#include +#include +#include + #include #include #include @@ -79,9 +83,10 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op #endif auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto hz = reinterpret_cast(hZ); - BUILD_SINGLE_SELECTOR(xType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -111,9 +116,10 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc, #endif auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); Nd4jLong* hz = reinterpret_cast(hZ); - BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); // BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES); } @@ -240,6 +246,66 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); } + + +//////////////////////////////////////////////////////////////////////// 
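// ============================================================================
// [Editorial sketch - not part of the patch] The lastErrorCode() /
// lastErrorMessage() exports declared in NativeOps.h above are the consumer
// side of the try/catch blocks added throughout NativeOps.cpp below, which
// record any caught std::exception via
// nd4j::LaunchContext::defaultContext()->errorReference() instead of letting
// it escape the C boundary. A minimal, hypothetical caller-side check could
// look like this (the include path and the helper name are assumptions, not
// part of this change):

#include <NativeOps.h>
#include <cstdio>

// Returns true when no native error has been recorded so far.
static bool checkNativeStatus(const char *where) {
    if (lastErrorCode() != 0) {
        // an exported op caught an exception and recorded it for us to read
        std::printf("%s failed: %s\n", where, lastErrorMessage());
        return false;
    }
    return true;
}

// Usage: invoke it right after any exported call wrapped by this patch, e.g.
// execPairwiseTransform(...); checkNativeStatus("execPairwiseTransform");
// ============================================================================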
+void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + +void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -295,9 +361,42 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (xType != yType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); + + if (zType != nd4j::DataType::BOOL) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong
*hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -737,6 +836,64 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hSscalarShapeInfo, + void *dScalar, Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism) { + +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars,
dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index f5d4996e4..86bc04fc4 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -102,8 +102,12 @@ void execIndexReduceScalar(Nd4jPointer *extraPointers, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + NativeOpExecutioner::execIndexReduceScalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -125,31 +129,36 @@ void execIndexReduce(Nd4jPointer *extraPointers,int opNum, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); + auto hz = reinterpret_cast(hZ); - auto hz = reinterpret_cast(hZ); - - NativeOpExecutioner::execIndexReduce(nullptr, opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hz, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); + NativeOpExecutioner::execIndexReduce(nullptr, opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hz, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + hTADShapeInfo, + hTADOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -175,31 +184,38 @@ void execBroadcast(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength); + auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); + auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = 
tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - NativeOpExecutioner::execBroadcast(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hY, - hYShapeInfo, - dY, - dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); + NativeOpExecutioner::execBroadcast(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hY, + hYShapeInfo, + dY, + dYShapeInfo, + hZ, hZShapeInfo, + dZ, dZShapeInfo, + dimension, + dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execBroadcastBool(Nd4jPointer *extraPointers, @@ -212,31 +228,39 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); - auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, dimensionLength); + auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); + auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hZShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); - auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); - auto hTADOffsetsZ = tadPackZ.primaryOffsets(); + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfoZ = tadPackZ.primaryShapeInfo(); + auto hTADOffsetsZ = tadPackZ.primaryOffsets(); - NativeOpExecutioner::execBroadcastBool(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hY, - hYShapeInfo, - dY, - dYShapeInfo, - hZ, hZShapeInfo, - dZ, dZShapeInfo, - dimension, - dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); + NativeOpExecutioner::execBroadcastBool(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hY, + hYShapeInfo, + dY, + dYShapeInfo, + hZ, hZShapeInfo, + dZ, dZShapeInfo, + dimension, + dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, + hTADOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -261,21 +285,26 @@ void execPairwiseTransform( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - NativeOpExecutioner::execPairwiseTransform(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hY, - hYShapeInfo, - dY, - dYShapeInfo, - hZ, - hZShapeInfo, - dZ, - 
dZShapeInfo, - extraParams); + try { + NativeOpExecutioner::execPairwiseTransform(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hY, + hYShapeInfo, + dY, + dYShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execPairwiseTransformBool( @@ -288,21 +317,27 @@ void execPairwiseTransformBool( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - NativeOpExecutioner::execPairwiseBoolTransform(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hY, - hYShapeInfo, - dY, - dYShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams); + + try { + NativeOpExecutioner::execPairwiseBoolTransform(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hY, + hYShapeInfo, + dY, + dYShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -323,18 +358,22 @@ void execReduceFloat( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - NativeOpExecutioner::execReduceFloatScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo); - + try { + NativeOpExecutioner::execReduceFloatScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceSame( @@ -346,18 +385,22 @@ void execReduceSame( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - NativeOpExecutioner::execReduceSameScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo); - + try { + NativeOpExecutioner::execReduceSameScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceBool( @@ -368,19 +411,22 @@ void execReduceBool( void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - NativeOpExecutioner::execReduceBoolScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo); - + try { + NativeOpExecutioner::execReduceBoolScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceLong( @@ -391,19 +437,22 @@ void execReduceLong( void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - NativeOpExecutioner::execReduceLongScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - 
dZShapeInfo); - + try { + NativeOpExecutioner::execReduceLongScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -424,28 +473,34 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPackX.primaryShapeInfo(); - auto hTADOffsets = tadPackX.primaryOffsets(); + auto hTADShapeInfo = tadPackX.primaryShapeInfo(); + auto hTADOffsets = tadPackX.primaryOffsets(); - NativeOpExecutioner::execReduceFloat(nullptr, opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); + NativeOpExecutioner::execReduceFloat(nullptr, opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + hTADShapeInfo, + hTADOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceBool2(Nd4jPointer *extraPointers, @@ -457,28 +512,34 @@ void execReduceBool2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduceBool(nullptr, opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); + NativeOpExecutioner::execReduceBool(nullptr, opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + hTADShapeInfo, + hTADOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceSame2(Nd4jPointer *extraPointers, @@ -490,28 +551,34 @@ void 
execReduceSame2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduceSame(nullptr, opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); + NativeOpExecutioner::execReduceSame(nullptr, opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + hTADShapeInfo, + hTADOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduceLong2(Nd4jPointer *extraPointers, @@ -523,28 +590,34 @@ void execReduceLong2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduceLong(nullptr, opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - hTADShapeInfo, - hTADOffsets); + NativeOpExecutioner::execReduceLong(nullptr, opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + hTADShapeInfo, + hTADOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -567,8 +640,13 @@ void execReduce3(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - NativeOpExecutioner::execReduce3(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + NativeOpExecutioner::execReduce3(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, + dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + } 
catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -588,8 +666,13 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - NativeOpExecutioner::execReduce3Scalar(nullptr, opNum,hX,hXShapeInfo,dX, dXShapeInfo,extraParams,hY,hYShapeInfo,dY,dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + NativeOpExecutioner::execReduce3Scalar(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** * @@ -617,19 +700,31 @@ void execReduce3Tad(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - if (extraPointers == nullptr || extraPointers[2] == 0) { - NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - } else { - // going tad-way - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + if (extraPointers == nullptr || extraPointers[2] == 0) { + NativeOpExecutioner::execReduce3(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, dXShapeInfo, + extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, + yTadOnlyShapeInfo, yTadOffsets); + } else { + // going tad-way + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, + dimensionLength); - auto hTADShapeInfo = tadPack.primaryShapeInfo(); - auto hTADOffsets = tadPack.primaryOffsets(); + auto hTADShapeInfo = tadPack.primaryShapeInfo(); + auto hTADOffsets = tadPack.primaryOffsets(); - NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, nullptr, nullptr); + NativeOpExecutioner::execReduce3TAD(LaunchContext::defaultContext(), opNum, hX, hXShapeInfo, dX, + dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, + hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, hTADShapeInfo, + hTADOffsets, nullptr, nullptr); + } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } } @@ -654,36 +749,9 @@ void execScalar( void *hScalar, Nd4jLong *hScalarShapeInfo, void *dScalar, Nd4jLong *dScalarShapeInfo, void *extraParams) { - NativeOpExecutioner::execScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - 
dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - hScalar, - hScalarShapeInfo, - dScalar, - dScalarShapeInfo, - extraParams); -} - -void execScalarBool( - Nd4jPointer *extraPointers, - int opNum, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hScalar, Nd4jLong *hScalarShapeInfo, - void *dScalar, Nd4jLong *dScalarShapeInfo, - void *extraParams) { - - NativeOpExecutioner::execScalarBool(nullptr, - opNum, + try { + NativeOpExecutioner::execScalar(nullptr, + opNum, hX, hXShapeInfo, dX, @@ -696,7 +764,43 @@ void execScalarBool( hScalarShapeInfo, dScalar, dScalarShapeInfo, - extraParams); + extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } +} + +void execScalarBool( + Nd4jPointer *extraPointers, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hScalarShapeInfo, + void *dScalar, Nd4jLong *dScalarShapeInfo, + void *extraParams) { + try { + NativeOpExecutioner::execScalarBool(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + hScalar, + hScalarShapeInfo, + dScalar, + dScalarShapeInfo, + extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -714,18 +818,23 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, bool biasCorrected) { - NativeOpExecutioner::execSummaryStatsScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - biasCorrected); + try { + NativeOpExecutioner::execSummaryStatsScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** * @@ -744,18 +853,23 @@ void execSummaryStats(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, bool biasCorrected) { - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - biasCorrected); + try { + NativeOpExecutioner::execSummaryStats(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** * @@ -779,27 +893,31 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = 
static_cast(shape::length(hDimensionShape)); - NativeOpExecutioner::execSummaryStats(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - biasCorrected); - + NativeOpExecutioner::execSummaryStats(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + dimension, + dimensionLength, + tadShapeInfo, + tadOffsets, + biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -820,20 +938,24 @@ void execTransformFloat( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - NativeOpExecutioner::execTransformFloat(nullptr, - opNum, - hX, - hXShapeInfo, - dZ, - dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams, - nullptr, - nullptr); + try { + NativeOpExecutioner::execTransformFloat(nullptr, + opNum, + hX, + hXShapeInfo, + dZ, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams, + nullptr, + nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execTransformSame( @@ -844,20 +966,24 @@ void execTransformSame( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - NativeOpExecutioner::execTransformSame(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams, - nullptr, - nullptr); + try { + NativeOpExecutioner::execTransformSame(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams, + nullptr, + nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execTransformBool( @@ -868,20 +994,24 @@ void execTransformBool( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - NativeOpExecutioner::execTransformBool(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams, - nullptr, - nullptr); + try { + NativeOpExecutioner::execTransformBool(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams, + nullptr, + nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execTransformAny( @@ -892,20 +1022,24 @@ void execTransformAny( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - NativeOpExecutioner::execTransformAny(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams, - nullptr, - nullptr); + try { + NativeOpExecutioner::execTransformAny(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams, + nullptr, + nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + 
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execTransformStrict( @@ -916,20 +1050,24 @@ void execTransformStrict( void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - NativeOpExecutioner::execTransformStrict(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - extraParams, - nullptr, - nullptr); + try { + NativeOpExecutioner::execTransformStrict(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + extraParams, + nullptr, + nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execReduce3All(Nd4jPointer *extraPointers, @@ -948,158 +1086,18 @@ void execReduce3All(Nd4jPointer *extraPointers, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - NativeOpExecutioner::execReduce3All(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); -} - - -template -void flattenGeneric(Nd4jPointer *extraPointers, - int offset, - char order, - void *vresult, - Nd4jLong *hZShapeInfo, - void *vinput, - Nd4jLong *inputShapeInfo) { - - auto hZ = reinterpret_cast(vresult); - auto input = reinterpret_cast(vinput); - - int numOnes = 0; - auto shape = shape::shapeOf(inputShapeInfo); - int wholeRank = shape::rank(inputShapeInfo); - for(int i = 0; i < wholeRank; i++) { - if(shape[i] == 1) - numOnes++; + NativeOpExecutioner::execReduce3All(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } - - - - //start at the given offset - hZ += offset; - char inputOrder = shape::order(inputShapeInfo); - auto len = shape::length(inputShapeInfo); - auto resultEleStride = shape::elementWiseStride(hZShapeInfo); - auto inputEleStride = shape::elementWiseStride(inputShapeInfo); - Nd4jLong numTads, stride; - int dimension, dimensionLength; - int rank = shape::rank(inputShapeInfo); - auto xStride = shape::stride(inputShapeInfo); - auto xShape = shape::shapeOf(inputShapeInfo); - - dimensionLength = 1; - if(order == 'f') { - dimension = 0; - } - else { - dimension = rank - 1; - } - stride = xStride[dimension]; - // numTads is product of length of all dimensions excluding - // the one we do the tad on - numTads = 1; - for (int i = 0; i < rank; i++) { - if (i != dimension) - numTads *= xShape[i]; - } - - if (inputOrder == order) { - if (resultEleStride == 1 && inputEleStride == 1) { - memcpy(hZ, input, len* sizeof(T)); - } - else if (resultEleStride >= 1 && inputEleStride >= 1) { - if (len < ELEMENT_THRESHOLD) { - - PRAGMA_OMP_SIMD - for (Nd4jLong i = 0; i < len; i++) { - hZ[i * resultEleStride] = input[i * inputEleStride]; - } - } - else { - - 
PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong i = 0; i < len; i++) { - hZ[i * resultEleStride] = input[i * inputEleStride]; - } - } - } - else { - int idx = 0; - for(Nd4jLong i = 0; i < len; i++) - hZ[idx++] = input[shape::getIndexOffset(i, inputShapeInfo, len)]; - } - } - else { - int rank = shape::rank(inputShapeInfo); - auto xShape = shape::shapeOf(inputShapeInfo); - auto tadShape = xShape[dimension]; - - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(inputShapeInfo, dimension); - - PRAGMA_OMP_PARALLEL_FOR - for(int i = 0; i < numTads; i++) { - - Nd4jLong resultOffset; - - if (order == 'f') { - // 1. get c ordering coordinates - auto cIndexCoordinates = new Nd4jLong[rank - 1]; - Nd4jLong divisor = 1; - for (int dim = rank - 1; dim > 0; dim--) { - cIndexCoordinates[dim - 1] = (i / divisor) % xShape[dim]; - divisor *= xShape[dim]; - } - - - // 2. convert to f ordering index - int fIndex = 0; - Nd4jLong multiplier = 1; - for (int dim = 1; dim <= rank - 1; dim++) { - fIndex += cIndexCoordinates[dim - 1] * multiplier; - multiplier *= xShape[dim]; - } - - resultOffset = fIndex * tadShape; - delete[] cIndexCoordinates; - - } - else { - resultOffset = i * tadShape; - } - - auto tadOffset = tadPack.primaryOffsets()[i]; - for( int j = 0; j < tadShape; j++) { - - // TAD are returned in C ordering always - hZ[resultOffset + j] = input[tadOffset + j * stride]; - - } - } - } -} - - -/** - * Concatneate multi array of the same shape together - * along a particular dimension - */ -void concat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfo, - Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - Nd4jPointer *tadPointers, - Nd4jPointer *offsetPointers) { - - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, nd4j::SpecialMethods, ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); } /** @@ -1116,39 +1114,14 @@ void specialConcat( Nd4jLong *hZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { + try { + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - - BUILD_SINGLE_SELECTOR(zType, nd4j::SpecialMethods, ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); -} - -/** -* Append an input array -* to the end of a flat array -* in a particular order -* @param offset the offset of the array to start at -* @param order the order -* @param hZ the hZ array -* @param hZShapeInfo the shape info for te array -* @param input the input for the array -* @param inputShapeInfo the shape information for that array -*/ -void flatten( - Nd4jPointer *extraPointers, - int offset, - char order, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *input, Nd4jLong *inputShapeInfo, - void *dinput, Nd4jLong *dinputShapeInfo) { - - auto xType = nd4j::ArrayOptions::dataType(inputShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - - if (xType != zType) - throw std::runtime_error("NativeOps::flatten requires all operands to have same data type"); - - BUILD_SINGLE_SELECTOR(xType, flattenGeneric, (extraPointers, offset, order, hZ, hZShapeInfo, input, inputShapeInfo), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(zType, nd4j::SpecialMethods,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES); + } 
catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -1324,7 +1297,13 @@ void setGridLimit(int gridSize) { nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimensionLength) { auto pack = new TadPack(); - *pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + try { + *pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, dimension, dimensionLength); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } + return pack; } @@ -1421,9 +1400,14 @@ void pullRows(Nd4jPointer *extraPointers, Nd4jLong *tadOffsets, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + try { + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (hX, hXShapeInfo, hZ, hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, pullRowsGeneric, (hX, hXShapeInfo, hZ, hZShapeInfo, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } template @@ -1474,9 +1458,14 @@ void tear(Nd4jPointer *extraPointers, Nd4jLong *hZShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + try { + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearGeneric, (hX, hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, tearGeneric, (hX, hXShapeInfo, targets, hZShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1488,9 +1477,14 @@ void average(Nd4jPointer *extras, int n, Nd4jLong length, bool propagate) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + try { + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::averageGeneric(hX, z, hZShapeInfo, n, length, propagate), LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void accumulate(Nd4jPointer *extras, @@ -1500,10 +1494,14 @@ void accumulate(Nd4jPointer *extras, void *dz, Nd4jLong *dZShapeInfo, int n, Nd4jLong length) { + try { + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::accumulateGeneric(hX, hz, hZShapeInfo, n, length), LIBND4J_TYPES); + } catch 
(std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void enableP2P(bool enable) { @@ -1613,14 +1611,20 @@ void shuffle(Nd4jPointer *extras, int *shuffleMap, Nd4jPointer *tadShapeInfo, Nd4jPointer *tadOffsets) { - auto xShape = reinterpret_cast(hXShapeInfo); - auto zShape = reinterpret_cast(hZShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); + try { + auto xShape = reinterpret_cast(hXShapeInfo); + auto zShape = reinterpret_cast(hZShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); - auto xType = nd4j::ArrayOptions::dataType(xShape[0]); + auto xType = nd4j::ArrayOptions::dataType(xShape[0]); - BUILD_SINGLE_SELECTOR(xType, shuffleGeneric, (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(xType, shuffleGeneric, + (hX, xShape, hz, zShape, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1633,27 +1637,6 @@ void setOmpMinThreads(int threads) { // TODO: to be implemented } -/* -void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraA, - void *extraB, - double scalarA, - double scalarB) { - // no-op; -} -*/ - int getDevice() { return 0; } @@ -1671,31 +1654,35 @@ void execScalarTad(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(hDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalar(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - hScalars, - hScalarShapeInfo, - dScalars, - dScalarShapeInfo, - dimension, - shape::length(hDimensionShape), - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execScalar(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + hScalars, + hScalarShapeInfo, + dScalars, + dScalarShapeInfo, + dimension, + shape::length(hDimensionShape), + tadShapeInfo, + tadOffsets, + tadShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execScalarBoolTad(Nd4jPointer *extraPointers, @@ -1711,44 +1698,53 @@ void execScalarBoolTad(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + try { + auto dimension = reinterpret_cast(hDimension); + int 
dimensionLength = static_cast(shape::length(hDimensionShape)); - auto dimension = reinterpret_cast(hDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - NativeOpExecutioner::execScalarBool(nullptr, - opNum, - hX, - hXShapeInfo, - dX, - dXShapeInfo, - extraParams, - hZ, - hZShapeInfo, - dZ, - dZShapeInfo, - hScalars, - hScalarShapeInfo, - dScalars, - dScalarShapeInfo, - dimension, - dimensionLength, - tadShapeInfo, - tadOffsets, - tadShapeInfoZ, - tadOffsetsZ); + NativeOpExecutioner::execScalarBool(nullptr, + opNum, + hX, + hXShapeInfo, + dX, + dXShapeInfo, + extraParams, + hZ, + hZShapeInfo, + dZ, + dZShapeInfo, + hScalars, + hScalarShapeInfo, + dScalars, + dScalarShapeInfo, + dimension, + dimensionLength, + tadShapeInfo, + tadOffsets, + tadShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } const char * getDeviceName(int deviceId) { - if (!nameSet) { - name = reinterpret_cast(malloc(256 * sizeof(char))); + try { + if (!nameSet) { + name = reinterpret_cast(malloc(256 * sizeof(char))); - CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); + CHECK_ALLOC(name, "Failed to allocate new string buffer", 256); - std::memset(name, 0, 256 * sizeof(char)); - nameSet = true; + std::memset(name, 0, 256 * sizeof(char)); + nameSet = true; - // TODO: provide proper CPU model name here - sprintf(name, "x86-compatible CPU"); + // TODO: provide proper CPU model name here + sprintf(name, "x86-compatible CPU"); + } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } @@ -1768,8 +1764,12 @@ void execAggregate(Nd4jPointer *extraPointers,int opNum, void *realArguments, int numRealArguments, nd4j::DataType dtype) { - - BUILD_SINGLE_SELECTOR(dtype, NativeOpExecutioner::execAggregate, (nullptr, opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), FLOAT_TYPES); + try { + BUILD_SINGLE_SELECTOR(dtype, NativeOpExecutioner::execAggregate, (nullptr, opNum, arguments, numArguments, shapeArguments, numShapeArguments, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), FLOAT_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1841,7 +1841,12 @@ void batchExecutor(Nd4jPointer *extraPointers, int maxReals, void *ptrToArguments, nd4j::DataType dtype) { - BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES); + try { + BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execAggregateBatch(Nd4jPointer *extraPointers, @@ -1855,7 +1860,12 @@ void execAggregateBatch(Nd4jPointer *extraPointers, int 
maxReals, void *ptrToArguments, nd4j::DataType dtype) { - BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES); + try { + BUILD_SINGLE_SELECTOR(dtype, _batchExecutor, (extraPointers, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments, dtype), FLOAT_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1865,7 +1875,12 @@ void execRandom(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - NativeOpExecutioner::execRandom(nullptr, opNum, state, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execRandom3(Nd4jPointer *extraPointers, @@ -1878,8 +1893,12 @@ void execRandom3(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execRandom2(Nd4jPointer *extraPointers, @@ -1890,19 +1909,25 @@ void execRandom2(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - - NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + NativeOpExecutioner::execRandom(nullptr, opNum, state, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) { - graph::RandomGenerator* generator = new graph::RandomGenerator(seed, seed); -// auto ptrBuf = reinterpret_cast(ptrToBuffer); -// auto buffer = new nd4j::random::RandomBuffer(seed, bufferSize, reinterpret_cast(ptrBuf)); -// -// nd4j::random::Xoroshiro128 generator(buffer); -// generator.refreshBuffer(); -// - return (Nd4jPointer) generator; + try { + auto generator = new graph::RandomGenerator(seed, seed); + + return (Nd4jPointer) generator; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + + return nullptr; + } } void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) { @@ 
-1953,7 +1978,12 @@ void sort(Nd4jPointer *extraPointers, void *hX, Nd4jLong *hXShapeInfo, void *dX, Nd4jLong *dXShapeInfo, bool descending) { - NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, descending); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortTad(Nd4jPointer *extraPointers, @@ -1964,7 +1994,12 @@ void sortTad(Nd4jPointer *extraPointers, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending) { - NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); + try { + NativeOpExecutioner::execSort(hX, hXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortCooIndices(Nd4jPointer *extraPointers, @@ -1972,7 +2007,12 @@ void sortCooIndices(Nd4jPointer *extraPointers, void *values, Nd4jLong length, int rank) { - NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); + try { + NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong *hXShapeInfo, Nd4jLong N, int *dz, float threshold) { @@ -1983,7 +2023,7 @@ Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong *hXShapeInf Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { auto hZ = new Nd4jLong[2];errno = 0; - +try { #if defined(_WIN32) || defined(_WIN64) _mmap(hZ, static_cast(length), fileName); #else @@ -1992,7 +2032,7 @@ Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong le nd4j_printf("Errno: %i\n", errno); throw std::runtime_error("Failed to open file for MMAP"); } - void * ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + void *ptr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); // check for failed allocation if (ptr == MAP_FAILED) @@ -2004,7 +2044,11 @@ Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong le #endif return hZ; - +} catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; +} } void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { @@ -2019,7 +2063,13 @@ void munmapFile(Nd4jPointer *extraPointers, Nd4jLong *ptrMap, Nd4jLong length) { } nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + try { + return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) { @@ -2061,8 
+2111,14 @@ FORCEINLINE int estimateThresholdGeneric(Nd4jPointer *extraPointers, Nd4jPointer int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXShapeInfo, int N, float threshold) { - auto xType = ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); + try { + auto xType = ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 0; + } } Nd4jLong getShapeListSize(nd4j::ShapeList* list) { @@ -2122,9 +2178,15 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D } nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp *op, Nd4jPointer* inputShapes, int numInputShapes, double *tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { @@ -2147,16 +2209,28 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D } nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); + return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - auto context = reinterpret_cast(opContext); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + auto context = reinterpret_cast(opContext); - return op->execute(context); + return op->execute(context); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 20; + } } Nd4jStatus 
realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { @@ -2234,34 +2308,6 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4 outputs[e]->streamline(shape::order(reinterpret_cast(outputShapes[e]))); } -/* - if (!isInplace) { - if (hZ->size() != numOutputs) { - return ND4J_STATUS_BAD_OUTPUT; - } - - for (int e = 0; e < numOutputs; e++) { - auto buffer = (T *) outputBuffers[e]; - auto shape = (int *) outputShapes[e]; - nd4j::NDArray tmp(buffer, shape); - - if (tmp.lengthOf() != hZ->at(e)->lengthOf()) { - nd4j_printf("Provided output array for [%s] has length of %i, but actual hZ has length of %i\n", op->getOpName()->c_str(), tmp.lengthOf(), hZ->at(e)->lengthOf()); - return ND4J_STATUS_BAD_OUTPUT; - } - - tmp.assign(hZ->at(e)); - } - } else { - // if op is inplace, our ResultSet holds pointers - hZ->purge(); - } - - - delete hZ; - -*/ - for (auto v: inputs) delete v; @@ -2273,16 +2319,28 @@ Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4 int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { - auto graph = nd4j::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); + try { + auto graph = nd4j::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); - nd4j::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); + nd4j::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { @@ -2478,7 +2536,13 @@ Nd4jStatus execCustomOpWithScope_(Nd4jPointer *extraPointers, nd4j::graph::Graph } Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int 
numOutputs) { - return execCustomOpWithScope_(extraPointers, reinterpret_cast(state), opHash, scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs); + try { + return execCustomOpWithScope_(extraPointers, reinterpret_cast(state), opHash, scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } void deleteResultWrapper(Nd4jPointer ptr) { @@ -2704,73 +2768,98 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, int* hIindexes, int* dIindexes) { + try { - int numThreads = omp_get_max_threads(); + int numThreads = omp_get_max_threads(); - PRAGMA_OMP_PARALLEL_THREADS(numThreads) - { - for (int i = 0; i < numOfSubArrs; ++i) { + PRAGMA_OMP_PARALLEL_THREADS(numThreads) + { + for (int i = 0; i < numOfSubArrs; ++i) { - int threadIndex = omp_get_thread_num(); - const auto xIndex = hIindexes[i]; - const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; + int threadIndex = omp_get_thread_num(); + const auto xIndex = hIindexes[i]; + const bool isOwner = xIndex < numThreads ? threadIndex == xIndex : threadIndex == xIndex % numThreads; - if (!isOwner) - continue; - - NDArray inSubArr(reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), hXShapeInfo); - NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), hYShapeInfo); - - if (inSubArr.lengthOf() != updSubArr.lengthOf()) { - continue; - } - - switch (opCode) { - case 0: - inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); - break; - case 1: - inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); - break; - case 2: - inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); - break; - case 3: - inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); - break; - case 4: - inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); - break; - case 5: - inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr); - break; - case 6: - inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); - break; - default: + if (!isOwner) continue; + + NDArray inSubArr( + reinterpret_cast(hX) + (hXOffsets[hIindexes[i]] * DataTypeUtils::sizeOf(hXShapeInfo)), + hXShapeInfo); + NDArray updSubArr(reinterpret_cast(hY) + (hYOffsets[i] * DataTypeUtils::sizeOf(hXShapeInfo)), + hYShapeInfo); + + if (inSubArr.lengthOf() != updSubArr.lengthOf()) { + continue; + } + + switch (opCode) { + case 0: + inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr); + break; + case 1: + inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr); + break; + case 2: + inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr); + break; + case 3: + inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr); + break; + case 4: + inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr); + break; + case 5: + inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, 
nullptr); + break; + case 6: + inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr); + break; + default: + continue; + } } } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } } void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, shapeInfo); - nd4j::DebugHelper::retrieveDebugStatistics(p, &array); + try { + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, shapeInfo); + nd4j::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { - auto buf = reinterpret_cast(p); - int cnt = 0; - for (int i = 0; i < len; i++) - cnt += buf[cnt]; + try { + auto buf = reinterpret_cast(p); + int cnt = 0; + for (int i = 0; i < len; i++) + cnt += buf[cnt]; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty) { - auto buffer = new ConstantDataBuffer(); - *buffer = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; + try { + auto buffer = new ConstantDataBuffer(); + *buffer = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) { @@ -2790,7 +2879,13 @@ nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *dat } nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) { - return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); + try { + return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) { @@ -2808,7 +2903,13 @@ Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) { nd4j::graph::Context* createGraphContext(int nodeId) { - return new nd4j::graph::Context(nodeId); + try { + return new nd4j::graph::Context(nodeId); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { return &ptr->randomGenerator(); @@ 
-2872,32 +2973,38 @@ int dataTypeFromNpyHeader(void *header) { } Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for(unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; + try { + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; - if (arr.shape[i] == 0) - _empty = true; + if (arr.shape[i] == 0) + _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + } + return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 
'f' : 'c', shape); - } - return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); } void sortByKey(Nd4jPointer *extraPointers, @@ -2906,10 +3013,15 @@ void sortByKey(Nd4jPointer *extraPointers, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending) { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortByValue(Nd4jPointer *extraPointers, @@ -2918,11 +3030,15 @@ void sortByValue(Nd4jPointer *extraPointers, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending) { + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); - - BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortTadByKey(Nd4jPointer *extraPointers, @@ -2933,10 +3049,15 @@ void sortTadByKey(Nd4jPointer *extraPointers, int *dimension, int dimensionLength, bool descending) { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortTadByValue(Nd4jPointer *extraPointers, @@ -2947,24 +3068,35 @@ void sortTadByValue(Nd4jPointer *extraPointers, int *dimension, int dimensionLength, bool descending) { - auto xType = ArrayOptions::dataType(xShapeInfo); - auto yType = ArrayOptions::dataType(yShapeInfo); + try { + auto xType = ArrayOptions::dataType(xShapeInfo); + auto yType = ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, 
dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } const char* runLightBenchmarkSuit(bool printOut) { - nd4j::LightBenchmarkSuit suit; - auto result = suit.runSuit(); + try { + nd4j::LightBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) + nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length()+1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char) 0x0; - return chars; + return chars; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { @@ -2972,17 +3104,23 @@ Nd4jLong getCachedMemory(int deviceId) { } const char* runFullBenchmarkSuit(bool printOut) { - nd4j::FullBenchmarkSuit suit; - auto result = suit.runSuit(); + try { + nd4j::FullBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) + nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length()+1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char) 0x0; - return chars; + return chars; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } nd4j::LaunchContext* defaultLaunchContext() { @@ -3017,8 +3155,14 @@ Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { return nullptr; } +int lastErrorCode() { + return nd4j::LaunchContext::defaultContext()->errorReference()->errorCode(); +} + +const char* lastErrorMessage() { + return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage(); +} -BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void shuffleGeneric, (void**, Nd4jLong**, void**, Nd4jLong**, int, int*, Nd4jLong**, Nd4jLong**), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index b3573c7ab..fcb473820 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -38,19 +38,22 @@ #include #include #include -#include #include #include #include +#include #include +#include +#include #include #include #include #include -#include #include #include +#include #include +#include using namespace nd4j; @@ -152,6 +155,39 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc, throw 
cuda_exception::build("execPairwiseBoolTransform failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams) { + + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); + + if (yType != xType || zType != xType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseIntTransform failed", res); +} + //////////////////////////////////////////////////////////////////////// void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc, int opNum, @@ -252,6 +288,81 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, throw cuda_exception::build("execInverseBroadcastBool failed", res); } + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); + + if (yType != xType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("F3B opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execBroadcastBool failed", res); +} + +void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong 
*dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); + + if (yType != xType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("F3BI opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastInt failed", res); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -475,12 +586,12 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc, auto numBlocks = shape::length(hZShapeInfo); dim3 launchDims(numBlocks, 256, 32768); - if (zType != nd4j::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT64 type", zType); + if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32) + throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType); auto dz = reinterpret_cast(dZ); - BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -567,12 +678,12 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, // FIXME: we want Z to be one of integer types //if (!DataTypeUtils::isZ(zType)) // throw nd4j::datatype_exception("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have one of integer types") - if (zType != nd4j::DataType::INT64) - throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT64 data type", zType); + if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT32/INT64 data type", zType); auto dz = reinterpret_cast(dZ); - BUILD_SINGLE_SELECTOR(xType, 
functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, @@ -580,7 +691,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, nullptr, 0, 1, allocationPointer, reductionPointer, - nullptr, nullptr), LIBND4J_TYPES); + nullptr, nullptr), LIBND4J_TYPES, INDEXING_TYPES); // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); if (res != 0) @@ -1114,6 +1225,75 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, throw cuda_exception::build("execScalarBool B failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hScalarShapeInfo, + void *dScalar, Nd4jLong *dScalarShapeInfo, + void *extraParams, bool allowParallelism) { + + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType) ) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execScalarInt failed", res); +} + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType) ) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw 
cuda_exception::build("execScalarInt B failed", res); +} + //////////////////////////////////////////////////////////////////////// void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, int opNum, diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index e75aa422c..7e74c3237 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -68,21 +68,6 @@ int minThreads = 32; __constant__ char deviceConstantMemory[49152]; -typedef struct { - long streamId; - long callId; -} __syncInfo; - -typedef __syncInfo SyncInfo; - - -// this method isn't used, left here for legacy and caution purposes -// TLDR: don't use this way, it sucks -void CUDART_CB syncCallback(cudaStream_t stream, cudaError_t status, void *data){ - SyncInfo *sync = reinterpret_cast(data); - - //printf("Finished stream: [%i], kernel call: [%i]\n", sync->streamId, sync->callId); -} // this method just does type conversion in fancy way int getDeviceId(Nd4jPointer ptrToDeviceId) { @@ -250,9 +235,14 @@ void execPairwiseTransform( Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execPairwiseTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -265,9 +255,14 @@ void execPairwiseTransformBool(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execPairwiseBoolTransform(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, + dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -279,9 +274,14 @@ void execSummaryStatsScalar(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, bool biasCorrected) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execSummaryStatsScalar(&lc, opNum, hX, 
hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -295,24 +295,30 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { + try { + //Nd4jLong *tadOnlyShapeInfo = reinterpret_cast(extraPointers[0]); + //Nd4jLong *tadOffsets = reinterpret_cast(extraPointers[1]); + //Nd4jLong *tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[2]); + //Nd4jLong *tadOffsetsZ = reinterpret_cast(extraPointers[3]); - //Nd4jLong *tadOnlyShapeInfo = reinterpret_cast(extraPointers[0]); - //Nd4jLong *tadOffsets = reinterpret_cast(extraPointers[1]); - //Nd4jLong *tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[2]); - //Nd4jLong *tadOffsetsZ = reinterpret_cast(extraPointers[3]); + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); - - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -338,38 +344,33 @@ void execBroadcast( void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { -/* - cudaEvent_t start; - cudaEventCreateWithFlags(&start, cudaEventDisableTiming); - timespec tsX; - timespec tsY; - clock_gettime(CLOCK_REALTIME, &tsX); -*/ - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = 
reinterpret_cast(extraPointers[9]); - auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); - auto tadOffsets = reinterpret_cast(extraPointers[11]); - auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); - auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto tadOnlyShapeInfo = reinterpret_cast(extraPointers[10]); + auto tadOffsets = reinterpret_cast(extraPointers[11]); + auto tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[12]); + auto tadOffsetsZ = reinterpret_cast(extraPointers[13]); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3 opNum:[%i]\n", opNum); + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("F3 opNum:[%i]\n", opNum); - //Nd4jLong *tadOnlyShapeInfo = reinterpret_cast(extraPointers[0]); - //Nd4jLong *tadOffsets = reinterpret_cast(extraPointers[1]); - //Nd4jLong *tadOnlyShapeInfoZ = reinterpret_cast(extraPointers[2]); - //Nd4jLong *tadOffsetsZ = reinterpret_cast(extraPointers[3]); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execBroadcast(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execBroadcast(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -390,9 +391,14 @@ void execReduceFloat(Nd4jPointer *extraPointers, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceFloatScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -403,9 +409,14 @@ void execReduceSame(Nd4jPointer *extraPointers, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSameScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + 
try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceSameScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -418,13 +429,22 @@ void execReduceSame2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), tadPack.specialOffsets()); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -437,13 +457,22 @@ void execReduceLong2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceLong(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), tadPack.specialOffsets()); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceLong(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + } catch (std::exception &e) { + 
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -454,30 +483,37 @@ void execReduceLong(Nd4jPointer *extraPointers, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("LF7 opNum:[%i]\n", opNum); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("LF7 opNum:[%i]\n", opNum); + auto reductionPointer = reinterpret_cast(extraPointers[4]); - auto reductionPointer = reinterpret_cast(extraPointers[4]); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (zType != nd4j::DataType::INT64) + throw datatype_exception::build("execReduceLong wrong Z data type", nd4j::DataType::INT64, zType); - if (zType != nd4j::DataType::INT64) - throw datatype_exception::build("execReduceLong wrong Z data type", nd4j::DataType::INT64, zType); + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, + dZ, dZShapeInfo, hXShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hXShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -490,13 +526,22 @@ void execReduceBool2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), tadPack.specialOffsets()); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -507,30 +552,37 @@ void execReduceBool(Nd4jPointer *extraPointers, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); + auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); - auto stream = reinterpret_cast(extraPointers[1]); - auto hTADShapeInfo = reinterpret_cast(extraPointers[9]); - auto dTADShapeInfo = reinterpret_cast(extraPointers[10]); + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("BF7 opNum:[%i]\n", opNum); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("BF7 opNum:[%i]\n", opNum); + auto reductionPointer = reinterpret_cast(extraPointers[4]); - auto reductionPointer = reinterpret_cast(extraPointers[4]); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (zType != nd4j::DataType::BOOL) + throw std::runtime_error("execReduceBool requires Z operand to have BOOL type"); - if (zType != nd4j::DataType::BOOL) - throw std::runtime_error("execReduceBool requires Z operand to have BOOL type"); + auto xLength = shape::length(hXShapeInfo); + auto blockWidth = 256; + auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth); + dim3 launchDims(numBlocks, blockWidth, 32768); - auto xLength = shape::length(hXShapeInfo); - auto blockWidth = 256; - auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, 
blockWidth); - dim3 launchDims(numBlocks, blockWidth, 32768); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, + ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, + dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, + dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -554,13 +606,22 @@ void execIndexReduce(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduce(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), tadPack.specialOffsets()); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execIndexReduce(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -582,13 +643,22 @@ void execReduceFloat2(Nd4jPointer *extraPointers, void *dZ, Nd4jLong *dZShapeInfo, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduceFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), 
tadPack.specialOffsets()); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduceFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, dimension, dimensionLength, tadPack.specialShapeInfo(), + tadPack.specialOffsets()); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } /** @@ -607,9 +677,14 @@ void execIndexReduceScalar( void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo){ - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execIndexReduceScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -619,12 +694,17 @@ void execTransformSame(Nd4jPointer *extraPointers,int opNum, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { + try { + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execTransformSame(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -634,12 +714,17 @@ void execTransformBool(Nd4jPointer *extraPointers,int opNum, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { + try { + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); + auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[1] : nullptr); - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[0] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? 
extraPointers[1] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execTransformBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -649,12 +734,18 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto streamSpecial = reinterpret_cast(extraPointers[4]); + LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], + reinterpret_cast(extraPointers[6])); - auto stream = reinterpret_cast(extraPointers[1]); - auto streamSpecial = reinterpret_cast(extraPointers[4]); - LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast(extraPointers[6])); - - NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr); + NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraParams, nullptr, nullptr); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -664,12 +755,17 @@ void execTransformStrict(Nd4jPointer *extraPointers,int opNum, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { + try { + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? 
extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformStrict(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execTransformStrict(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -679,55 +775,19 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { + try { + auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); + auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - auto tadShapeInfo = reinterpret_cast(extraPointers != nullptr ? extraPointers[10] : nullptr); - auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } - -/** - * Append an input array - * to the end of a flat array - * in a particular order - * @param offset the offset of the array to start at - * @param order the order - * @param dZ the dZ array - * @param dZShapeInfo the shape info for te array - * @param input the input for the array - * @param inputShapeInfo the shape information for that array - */ -void flatten(Nd4jPointer *extraPointers, - int offset, - char order, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *hInput, Nd4jLong *hInputShapeInfo, - void *dInput, Nd4jLong *dInputShapeInfo) { - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - auto hYShapeInfo = reinterpret_cast(extraPointers[7]); - - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F22 opNum:[7]\n"); - - // int *allocPointer = reinterpret_cast(extraPointers[3]); - - dim3 launchDims(256, 256, 2048); - - if (nd4j::Environment::getInstance()->isVerbose() && launchDims.x == 1) - printf("AF222 opNum:[7]\n"); - - auto type = nd4j::ArrayOptions::dataType(hInputShapeInfo); - BUILD_SINGLE_SELECTOR(type, flattenKernelGeneric, (launchDims, stream, extraPointers, offset, order, dZ, dZShapeInfo, dInput, dInputShapeInfo), LIBND4J_TYPES); - - DEBUG_KERNEL(stream, -1); -} - - - void checkP2P() { int curDevice = 0; @@ -821,23 +881,28 @@ bool isP2PAvailable() { void initializeDevicesAndFunctions() { - int devCnt = 0; - cudaGetDeviceCount(&devCnt); - deviceProperties = new 
cudaDeviceProp[devCnt]; - for (int i = 0; i < devCnt; i++) { - cudaSetDevice(i); - cudaGetDeviceProperties(&deviceProperties[i], i); + try { + int devCnt = 0; + cudaGetDeviceCount(&devCnt); + deviceProperties = new cudaDeviceProp[devCnt]; + for (int i = 0; i < devCnt; i++) { + cudaSetDevice(i); + cudaGetDeviceProperties(&deviceProperties[i], i); - cudaDeviceSetLimit(cudaLimitStackSize, 4096); - } + cudaDeviceSetLimit(cudaLimitStackSize, 4096); + } - cudaSetDevice(0); + cudaSetDevice(0); - checkP2P(); + checkP2P(); - // enabling p2p gpu access if it's supported - if (supportedP2P && devCnt > 1) - enableP2P(allowedP2P); + // enabling p2p gpu access if it's supported + if (supportedP2P && devCnt > 1) + enableP2P(allowedP2P); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void initializeFunctions(Nd4jPointer *functions) { @@ -867,8 +932,10 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer pointer; // cudaHostAllocMapped |cudaHostAllocPortable auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize, cudaHostAllocDefault); - if (res != 0) - throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res); + if (res != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed"); + } return pointer; } @@ -884,8 +951,11 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { Nd4jPointer pointer; auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize); - if (res != 0) - throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res); + if (res != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed"); + } + return pointer; } @@ -896,8 +966,11 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { */ int freeHost(Nd4jPointer pointer) { auto res = cudaFreeHost(reinterpret_cast(pointer)); - if (res != 0) - throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res); + if (res != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFreeHost failed"); + } + return 1L; } @@ -909,10 +982,14 @@ int freeHost(Nd4jPointer pointer) { */ int freeDevice(Nd4jPointer pointer, int deviceId) { auto res = cudaFree(reinterpret_cast(pointer)); - if (res != 0) - throw nd4j::cuda_exception::build("cudaFree(...) failed", res); - return 1L; + // we're intentionally skipping + if (res != 0 && res != 1) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaFree failed"); + } + + return res == 0 ? 1L : 0L; } @@ -921,22 +998,13 @@ Nd4jPointer createContext() { } Nd4jPointer createStream() { - /* - Nd4jPointer nativeStream = (Nd4jPointer) malloc(sizeof(cudaStream_t)); - CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); - - cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); - checkCudaErrors(dZ); - if (dZ != 0) - throw std::runtime_error("cudaStreamCreate(...) 
failed"); - - return nativeStream; - */ auto stream = new cudaStream_t(); auto dZ = cudaStreamCreate(stream); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamCreate failed"); + } return stream; } @@ -947,9 +1015,10 @@ Nd4jPointer createEvent() { CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t)); auto dZ = cudaEventCreateWithFlags(reinterpret_cast(&nativeEvent), cudaEventDisableTiming); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ); - + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventCreateWithFlags failed"); + } return nativeEvent; } @@ -959,8 +1028,10 @@ int registerEvent(Nd4jPointer event, Nd4jPointer stream) { auto pStream = reinterpret_cast(stream); auto dZ = cudaEventRecord(*pEvent, *pStream); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaEventRecord(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventRecord failed"); + } return 1; } @@ -1048,8 +1119,11 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j kind = cudaMemcpyDeviceToDevice; } break; - default: - throw nd4j::cuda_exception::build("UNDEFINED MEMCPY!\n", 119); + default: { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("UNDEFNED MEMCPY"); + return 0; + } } auto dZ = cudaMemcpyAsync(reinterpret_cast(dst), const_cast(reinterpret_cast(src)), static_cast(size), kind, *pStream); @@ -1058,7 +1132,8 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast(dZ)); fflush(stdout); fflush(stderr); - throw nd4j::cuda_exception::build("cudaMemcpyAsync(...) failed", dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyAsync failed"); } return 1; @@ -1066,8 +1141,10 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { auto dZ = cudaMemset(reinterpret_cast(dst), value, static_cast(size)); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemset failed"); + } return 1; } @@ -1076,8 +1153,10 @@ int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointe auto pStream = reinterpret_cast(reserved); auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, static_cast(size), *pStream); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaMemsetAsync(...) 
failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemsetAsync failed"); + } return 1; } @@ -1085,8 +1164,10 @@ int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointe int destroyEvent(Nd4jPointer event) { auto pEvent = reinterpret_cast(&event); auto dZ = cudaEventDestroy(*pEvent); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventDestroy failed"); + } return 1; } @@ -1095,8 +1176,10 @@ int streamSynchronize(Nd4jPointer stream) { auto pStream = reinterpret_cast(stream); auto dZ = cudaStreamSynchronize(*pStream); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaStreamSynchronize failed"); + } return 1L; } @@ -1105,8 +1188,10 @@ int eventSynchronize(Nd4jPointer event) { auto pEvent = reinterpret_cast(&event); auto dZ = cudaEventSynchronize(*pEvent); - if (dZ != 0) - throw nd4j::cuda_exception::build("cudaEventSynchronize(...) failed", dZ); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaEventSynchronize failed"); + } return 1L; } @@ -1162,268 +1247,6 @@ const char * getDeviceName(int device) { return deviceProperties[device].name; } -/////////////////////////////////////////////////////////////////// -template -__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { - - __shared__ int arrIdx, blocksPerArr; - __shared__ T *x, *z; - __shared__ Nd4jLong *zShapeInfo, *xShapeInfo, arrLen, arrLenZ, arrLenPerBlock, start, end; - - if (threadIdx.x == 0) { - blocksPerArr = (gridDim.x - gridDim.x % numOfArrs) / numOfArrs; // floor - arrIdx = blockIdx.x / blocksPerArr; - if (arrIdx >= numOfArrs) - arrIdx = numOfArrs - 1; - x = reinterpret_cast(reinterpret_cast(pVx)[arrIdx]); - z = reinterpret_cast(reinterpret_cast(pVz)[arrIdx]); - xShapeInfo = reinterpret_cast(pxShapeInfo)[arrIdx]; - zShapeInfo = reinterpret_cast(pzShapeInfo)[arrIdx]; - - arrLen = shape::length(xShapeInfo); - arrLenZ = shape::length(zShapeInfo); - arrLenPerBlock = (arrLen + blocksPerArr - arrLen % blocksPerArr) / blocksPerArr; // ceil - - start = arrLenPerBlock * (blockIdx.x % blocksPerArr); - end = (start + arrLenPerBlock) > arrLen ? 
arrLen : (start + arrLenPerBlock); - } - __syncthreads(); - - for (Nd4jLong i = threadIdx.x + start; i < end; i += blockDim.x) { - auto zOffset = shape::getIndexOffset(i, zShapeInfo, arrLenZ); - auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLen); - //printf("z[%i][%lld] = x[%i][%lld]\n", arrIdx, zOffset, arrIdx, xOffset); - z[zOffset] = x[xOffset]; - } -} -template -__host__ static void concatCudaLauncher(const int numOfArrs, cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { - //int blocks = numOfArrs * 16; // >> 1 << 2); - //nd4j_printf("gridDim.x is %i\n", blocks); - //if (blocks > 8192) - // blocks = 8192; // restrict grid dims to 8K max - concatCuda<<>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo); - nd4j::DebugHelper::checkErrorCode(stream, "concat(...) failed"); -} -BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES); - -static void -specialBufferAndShapeWithOffset(void* vZ, Nd4jLong* hZShapeInfo, Nd4jLong* dZShapeInfo, std::vector const& idx, void*& outBuffer, Nd4jLong*& outShape) { - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - const int rank = shape::rank(hZShapeInfo); - Nd4jLong* newShape = new Nd4jLong[shape::shapeInfoLength(rank)]; - //ALLOCATE(newShape, nullptr, , Nd4jLong) - auto shapeSize = shape::shapeInfoByteLength(rank); - memcpy(newShape, hZShapeInfo, shapeSize); - - auto shapeOf = shape::shapeOf(newShape); - auto stridesOf = shape::stride(newShape); - - Nd4jLong offset(0), subArrLen(1); - int n(2), first, last, stride; - - for (int d = rank - 1; d >= 0; --d) { - - if (idx[n * d] != idx[n * d + 1]) { - auto axeDim = shape::sizeAt(hZShapeInfo, d); - first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + axeDim + 1; - last = idx[n * d + 1] >= 0 ? 
idx[n * d + 1] : idx[n * d + 1] + axeDim + 1; - stride = 1; - - shapeOf[d] = (last - first + stride - 1) / stride; // ceil (last - first) / stride; - offset += first * stridesOf[d]; - - if(shapeOf[d] != 1) - stridesOf[d] *= stride; - } - - subArrLen *= shapeOf[d]; - } - - // check if there is possibility to set ews = 1 - //shape::setEws(newShape, subArrLen); - - //makeBothBuffersActual(); - outBuffer = (void*)((int8_t*)vZ + offset * DataTypeUtils::sizeOfElement(zType)); - cudaError_t err = cudaMalloc(&outShape, shapeSize); - if (err != 0) { - printf("Cannot allocate memory with error %d\n", err); - throw std::runtime_error("Cannot allocate memory for shape"); - } - cudaMemcpy(outShape, newShape, shapeSize, cudaMemcpyHostToDevice); - delete [] newShape; -} - -/** - * Concatneate multi array of the same shape together - * along a particular dimension - */ -void concat( - Nd4jPointer *extraPointers, - int dimension, - int numArrays, - Nd4jPointer *data, Nd4jPointer *inputShapeInfo, - Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { - - auto stream = reinterpret_cast(extraPointers[1]); - - auto hXShapeInfo = hZShapeInfo; - auto hShapePointers = reinterpret_cast(inputShapeInfo); - auto dShapePointers = reinterpret_cast(dinputShapeInfo); - // numArrays will be used as number of TADs, so each block process 1 input - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - auto axis = dimension; - - const int rank = shape::rank(hZShapeInfo); //reinterpret_cast(inputShapeInfo[0])); - const int rank2 = 2 * rank; - std::vector> indices(numArrays, std::vector(rank2,0)); - - // take into account indices for first array - auto axisSize = shape::sizeAt(reinterpret_cast(inputShapeInfo[0]), axis); - indices[0][2 * axis + 1] = axisSize; - - for(int i = 1; i < numArrays; ++i) { - indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from - indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + shape::sizeAt(reinterpret_cast(inputShapeInfo[i]), axis); // index end with (excluding) - } - - std::vector outSubArrsBuffs(numArrays); - std::vector outSubArrsShapes(numArrays); - for(int i = 0; i < numArrays; ++i) { - specialBufferAndShapeWithOffset(dZ, hZShapeInfo, dZShapeInfo, indices[i], outSubArrsBuffs[i], outSubArrsShapes[i]); - } - - LaunchContext context(stream); - PointersManager manager(&context, "concat"); - void* dOutBuffers = manager.replicatePointer(outSubArrsBuffs.data(), outSubArrsBuffs.size() * sizeof(void*)); - void* dInBuffers = manager.replicatePointer(ddata, numArrays * sizeof(void*)); - void* dInShapeInfo = manager.replicatePointer(dShapePointers, numArrays * sizeof(Nd4jLong*)); - void* dOutShapeInfo = manager.replicatePointer(outSubArrsShapes.data(), outSubArrsShapes.size() * sizeof(Nd4jLong*)); - - BUILD_SINGLE_SELECTOR(zType, concatCudaLauncher, (numArrays, stream, dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES); - manager.synchronize(); - - cudaError_t err; - for(int i = 0; i < numArrays; ++i) { - err = cudaFree(outSubArrsShapes[i]); - if (err != 0) { - printf("Error %d occured when shape %i was deallocating.\n", err, i); - throw std::runtime_error("Cannot deallocate memory for shapes."); - } - } -} - -/** - * Concatneate multi array of the same shape together - * along a particular dimension - */ -// void concat( -// Nd4jPointer *extraPointers, -// int dimension, -// int numArrays, -// Nd4jPointer *data, Nd4jPointer 
*inputShapeInfo, -// Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo, -// void *hZ, Nd4jLong *hZShapeInfo, -// void *dZ, Nd4jLong *dZShapeInfo, -// Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { -// -// cudaStream_t *stream = reinterpret_cast(extraPointers[1]); -// auto hXShapeInfo = hZShapeInfo; -// auto hShapePointers = reinterpret_cast(inputShapeInfo); -// // numArrays will be used as number of TADs, so each block process 1 input -// -// int smem = 8192; -// bool isVstack = false; -// bool isScalar = true; -// bool isHstack = false; -// -// for (int i = 0; i < numArrays; i++) { -// if (!shape::isScalar(hShapePointers[i])) { -// isScalar = false; -// break; -// } -// } -// -// if (!isScalar && dimension == 0 && shape::rank(hZShapeInfo) == 2 && shape::order(hZShapeInfo) == 'c' ) { -// isVstack = true; -// for (int i = 0; i < numArrays; i++) { -// if (!shape::isVector(hShapePointers[i]) || shape::elementWiseStride(hShapePointers[i]) <= 0 || -// shape::order(hShapePointers[i]) != 'c') { -// isVstack = false; -// break; -// } -// } -// } -// -// // let's try to fit N-dimensional vstack -// if (!isVstack && !isScalar && dimension == 0 && shape::order(hXShapeInfo) == 'c') { -// auto length0 = shape::length(hShapePointers[0]); -// isVstack = true; -// for (int i = 0; i < numArrays; i++) { -// if (shape::elementWiseStride(hShapePointers[i]) <= 0 || shape::order(hShapePointers[i]) != 'c' || length0 != shape::length(hShapePointers[i])) { -// isVstack = false; -// break; -// } -// } -// } -// -// if (!isScalar && !isVstack && dimension == 1 && shape::isVector(hZShapeInfo)) { -// isHstack = true; -// for (int i = 0; i < numArrays; i++) { -// if (!shape::isVector(hShapePointers[i]) || shape::elementWiseStride(hShapePointers[i]) <= 0) { -// isHstack = false; -// break; -// } -// } -// } -// -// if (isScalar) { -// if (nd4j::Environment::getInstance()->isDebugAndVerbose()) -// printf("Going scalar concat\n"); -// -// dim3 launchDims(128, 128, 16384); -// auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); -// BUILD_SINGLE_SELECTOR(zType, concatKernelScalarGeneric, (launchDims, stream, numArrays, reinterpret_cast(ddata[0]), dZ), LIBND4J_TYPES); -// -// } else if (isVstack) { -// if (nd4j::Environment::getInstance()->isDebugAndVerbose()) -// printf("Going VStack concat\n"); -// -// dim3 launchDims(128, 512, 16384); -// auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); -// BUILD_SINGLE_SELECTOR(zType, concatKernelVStackGeneric, (launchDims, stream, numArrays, reinterpret_cast(ddata[0]), reinterpret_cast(dinputShapeInfo[0]), dZ, dZShapeInfo), LIBND4J_TYPES); -// -// } else if (isHstack) { -// if (nd4j::Environment::getInstance()->isDebugAndVerbose()) -// printf("Going HStack concat\n"); -// -// dim3 launchDims(128, 128, 16384); -// auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); -// BUILD_SINGLE_SELECTOR(zType, concatKernelHStackGeneric, (launchDims, stream, numArrays, reinterpret_cast(ddata[0]), reinterpret_cast(dinputShapeInfo[0]), dZ, dZShapeInfo), LIBND4J_TYPES); -// } else { -// if (nd4j::Environment::getInstance()->isDebugAndVerbose()) -// printf("Going generic concat\n"); -// -// auto devZTadShape = reinterpret_cast(extraPointers[10]); -// auto devZOffsets = reinterpret_cast(extraPointers[11]); -// -// dim3 launchDims(128, 128, 8192); -// auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); -// BUILD_SINGLE_SELECTOR(zType, concatKernelGeneric, (launchDims, stream, numArrays, reinterpret_cast(ddata[0]), reinterpret_cast(dinputShapeInfo[0]), dZ, dZShapeInfo, 
reinterpret_cast(tadPointers[0]), reinterpret_cast(offsetPointers[0]), devZTadShape, devZOffsets), LIBND4J_TYPES); -// } -// if (nd4j::Environment::getInstance()->isDebugAndVerbose()) -// printf("sharedMemory requested for concatFloat: [%i], registers: [%i]\n", smem, funcAttributes[31].numRegs); -// -// cudaError_t res = cudaStreamSynchronize(*stream); -// checkCudaErrors(res); -// nd4j::DebugHelper::checkErrorCode(stream, "Legacy ConcatFloat(...) failed"); -//} - - - void specialConcat( Nd4jPointer *extraPointers, int dimension, @@ -1432,8 +1255,14 @@ void specialConcat( Nd4jPointer *inputShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { - - BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), nd4j::SpecialMethods ,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES); + try { + BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), nd4j::SpecialMethods, + ::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), + LIBND4J_TYPES); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1441,9 +1270,15 @@ void specialConcat( * This method saves */ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) { - auto pack = new TadPack(); - *pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength); - return pack; + try { + auto pack = new TadPack(); + *pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength); + return pack; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) { @@ -1489,11 +1324,11 @@ int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, } break; } - //cudaError_t dZ = cudaMemcpyAsync((void *) dst, (const void *) src, (size_t) size, kind, *pStream); - cudaError_t dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, const_cast(src), size, dst, kind, *pStream); - checkCudaErrors(dZ); - if (dZ != 0) - throw std::runtime_error("cudaMemcpyToSymbolAsync(...) failed"); + auto dZ = cudaMemcpyToSymbolAsync(deviceConstantMemory, const_cast(src), size, dst, kind, *pStream); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMemcpyToSymbolAsync failed"); + } return 1; } @@ -1502,8 +1337,10 @@ Nd4jPointer getConstantSpace() { Nd4jPointer dConstAddr; cudaError_t dZ = cudaGetSymbolAddress(reinterpret_cast(&dConstAddr), deviceConstantMemory); - if (dZ != 0) - throw std::runtime_error("cudaGetSymbolAddress(...) 
failed"); + if (dZ != 0) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(dZ); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaGetSymbolAddress failed"); + } return dConstAddr; } @@ -1519,13 +1356,19 @@ void pullRows(Nd4jPointer *extraPointers, Nd4jLong *tadOffsets, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + dim3 launchDims(64, 256, 1024); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, + (launchDims, stream, dX, dZ, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), + LIBND4J_TYPES); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - dim3 launchDims(64, 256, 1024); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, pullRowsKernelGeneric, (launchDims, stream, dX, dZ, n, indexes, tadShapeInfo, tadOffsets, zTadShapeInfo, zTadOffsets), LIBND4J_TYPES); - - DEBUG_KERNEL(stream, -1); + DEBUG_KERNEL(stream, -1); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1537,25 +1380,31 @@ void average(Nd4jPointer *extras, int n, Nd4jLong length, bool propagate) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); - cudaStream_t * stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); + auto dX = reinterpret_cast(dx); - auto dX = reinterpret_cast(dx); + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("averageFloat called\n"); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("averageFloat called\n"); - - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - // launching on gpu - if (mode == 0) { - dim3 launchDims(256, 256, 4096); - BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, (launchDims, stream, dX, dz, n, length, propagate), LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "AverageFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::averageGeneric(x, z, zShapeInfo, n, length, propagate), LIBND4J_TYPES); - } + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + // launching on gpu + if (mode == 0) { + dim3 launchDims(256, 256, 4096); + BUILD_SINGLE_SELECTOR(xType, averagingKernelGeneric, (launchDims, stream, dX, dz, n, length, propagate), + LIBND4J_TYPES); + nd4j::DebugHelper::checkErrorCode(stream, "AverageFloat(...) 
failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::averageGeneric(x, z, zShapeInfo, n, length, propagate), + LIBND4J_TYPES); + } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void accumulate(Nd4jPointer *extras, @@ -1565,25 +1414,31 @@ void accumulate(Nd4jPointer *extras, void *dz, Nd4jLong *dzShapeInfo, int n, Nd4jLong length) { + try { + auto stream = reinterpret_cast(extras[1]); + int mode = getDeviceId(extras[3]); - auto stream = reinterpret_cast(extras[1]); - int mode = getDeviceId(extras[3]); + auto dX = reinterpret_cast(dx); - auto dX = reinterpret_cast(dx); + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("accumulateFloat called\n"); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("accumulateFloat called\n"); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - - // launching on gpu - if (mode == 0) { - dim3 launchDims(n, 256, 16384); - BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, (launchDims, stream, dX, dz, n,length), LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); - } else { - // launching on host memory - BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::accumulateGeneric(x, z, zShapeInfo, n, length), LIBND4J_TYPES); - } + // launching on gpu + if (mode == 0) { + dim3 launchDims(n, 256, 16384); + BUILD_SINGLE_SELECTOR(xType, accumulateKernelGeneric, (launchDims, stream, dX, dz, n, length), + LIBND4J_TYPES); + nd4j::DebugHelper::checkErrorCode(stream, "AccumulateFloat(...) failed"); + } else { + // launching on host memory + BUILD_SINGLE_SELECTOR(xType, nd4j::SpecialMethods, ::accumulateGeneric(x, z, zShapeInfo, n, length), + LIBND4J_TYPES); + } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -1596,50 +1451,29 @@ void shuffle(Nd4jPointer *extras, int *shuffleMap, Nd4jPointer *tadShapeInfo, Nd4jPointer *tadOffsets) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); - cudaStream_t *stream = reinterpret_cast(extras[1]); + auto dX = reinterpret_cast(dx); + auto dZ = reinterpret_cast(dz); + auto xShape = reinterpret_cast(xShapeInfo); + auto dxShape = reinterpret_cast(dXShapeInfo); + auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); + auto tadOffset = reinterpret_cast(tadOffsets); - auto dX = reinterpret_cast(dx); - auto dZ = reinterpret_cast(dz); - auto xShape = reinterpret_cast(xShapeInfo); - auto dxShape = reinterpret_cast(dXShapeInfo); - auto tadOnlyShapeInfo = reinterpret_cast(tadShapeInfo); - auto tadOffset = reinterpret_cast(tadOffsets); + auto xType = nd4j::ArrayOptions::dataType(xShape[0]); + dim3 launchDims(256, 512, 8192); + BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, + (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, tadOnlyShapeInfo, tadOffset), + LIBND4J_TYPES); - auto xType = nd4j::ArrayOptions::dataType(xShape[0]); - dim3 launchDims(256, 512, 8192); - BUILD_SINGLE_SELECTOR(xType, shuffleKernelGeneric, (launchDims, stream, dX, dxShape, dZ, N, shuffleMap, tadOnlyShapeInfo, tadOffset), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "shuffle(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "shuffle(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } -/* -void execMetaPredicateShape(Nd4jPointer *extras, - const int opTypeA, - const int opNumA, - const int opTypeB, - const int opNumB, - Nd4jLong N, - void *hX, Nd4jLong *hXShapeInfo, - void *dX, Nd4jLong *dXShapeInfo, - void *hY, Nd4jLong *hYShapeInfo, - void *dY, Nd4jLong *dYShapeInfo, - void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo, - void *extraA, - void *extraB, - double scalarA, - double scalarB) { - - cudaStream_t *stream = reinterpret_cast(extras[1]); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, functions::grid::GRIDShaped, ::execMetaPredicateShaped(stream, extras, opTypeA, opNumA, opTypeB, opNumB, N, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraA, extraB, scalarA, scalarB), LIBND4J_TYPES); - // functions::grid::GRIDShaped::execMetaPredicateShaped(stream, extras, opTypeA, opNumA, opTypeB, opNumB, N, dX, dXShapeInfo, dy, dYShapeInfo, dz, zShapeInfo, extraA, extraB, scalarA, scalarB); - - DEBUG_KERNEL(stream, opNumA); -} -*/ - bool isExperimentalEnabled() { return nd4j::Environment::getInstance()->isExperimentalBuild(); } @@ -1670,9 +1504,14 @@ void execSummaryStats(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, bool biasCorrected) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo, biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1686,11 +1525,18 @@ void execSummaryStatsTad(Nd4jPointer *extraPointers, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, bool biasCorrected, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, biasCorrected); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execSummaryStats(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, + hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, + tadOffsets, biasCorrected); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } 
//////////////////////////////////////////////////////////////////////// @@ -1703,9 +1549,14 @@ void execReduce3(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1721,35 +1572,35 @@ void execReduce3Tad(Nd4jPointer *extraPointers, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - // if (extraPointers == nullptr || extraPointers[2] == 0) - // NativeOpExecutioner::execReduce3(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - // else { - // // going tad-ways - // auto tadShapeInfo = reinterpret_cast (extraPointers[0]); - // auto tadOffsets = reinterpret_cast(extraPointers[1]); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, + reinterpret_cast(hDimension), + shape::length(hDimensionShape)); + auto tadLength = shape::length(tadPack.primaryShapeInfo()); + auto yLength = shape::length(hYShapeInfo); + auto xLength = shape::length(hXShapeInfo); - // NativeOpExecutioner::execReduce3TAD(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets); - // } + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - // nd4j_printf("Starting...\n",""); + if (tadLength == yLength || tadLength == xLength) { + // nd4j_printf("== way\n",""); + NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, + dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, + tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); + } else + NativeOpExecutioner::execReduce3TAD(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, + dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, + yTadOnlyShapeInfo, yTadOffsets); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast(hDimension), shape::length(hDimensionShape)); - auto tadLength = shape::length(tadPack.primaryShapeInfo()); - auto yLength = 
shape::length(hYShapeInfo); - auto xLength = shape::length(hXShapeInfo); - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - - if (tadLength == yLength || tadLength == xLength) { - // nd4j_printf("== way\n",""); - NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, - tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets); - } else - NativeOpExecutioner::execReduce3TAD(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, yTadOffsets, yTadOnlyShapeInfo, yTadOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1761,9 +1612,14 @@ void execReduce3Scalar(Nd4jPointer *extraPointers,int opNum, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3Scalar(&lc, opNum,hX,hXShapeInfo,dX, dXShapeInfo,extraParams,hY,hYShapeInfo,dY,dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduce3Scalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1776,9 +1632,15 @@ void execScalarBool(Nd4jPointer *extraPointers, void *hScalar, Nd4jLong *hScalarShapeInfo, void *dScalar, Nd4jLong *dScalarShapeInfo, void *extraParams) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, extraParams); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, + extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1794,11 +1656,19 @@ void execScalarBoolTad(Nd4jPointer *extraPointers, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = 
static_cast(shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, hScalars, hScalarShapeInfo, dScalars, dScalarShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execScalarBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, + dZ, dZShapeInfo, hScalars, hScalarShapeInfo, dScalars, dScalarShapeInfo, + dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, + tadOffsetsZ); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1811,9 +1681,14 @@ void execScalar(Nd4jPointer *extraPointers, void *hScalar, Nd4jLong *hScalarShapeInfo, void *dScalar, Nd4jLong *dScalarShapeInfo, void *extraParams) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, extraParams); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execScalar(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, + hScalar, hScalarShapeInfo, dScalar, dScalarShapeInfo, extraParams); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1830,27 +1705,36 @@ void execScalarTad(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); - auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (yType != xType && yType != nd4j::DataType::BOOL && !isExperimentalEnabled()) - throw nd4j::datatype_exception::build("execScalar both operands must have same data type", xType, yType); + if (yType != xType && yType != nd4j::DataType::BOOL && !isExperimentalEnabled()) + throw nd4j::datatype_exception::build("execScalar both operands must have same data type", xType, yType); - dim3 launchDims(256, 256, 16384); + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(xType, yType, zType, 
functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, + ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, + dZShapeInfo, dScalars, extraParams, dimension, + dimensionLength, tadShapeInfo, tadOffsets, + tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES); #endif - DEBUG_KERNEL(stream, opNum); + DEBUG_KERNEL(stream, opNum); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void execAggregate(Nd4jPointer *extraPointers, @@ -1866,16 +1750,23 @@ void execAggregate(Nd4jPointer *extraPointers, void *realArguments, int numRealArguments, nd4j::DataType dtype) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + int numBlocks = getDeviceId(extraPointers[2]); + int numThreads = getDeviceId(extraPointers[3]); + int shmem = getDeviceId(extraPointers[4]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - int numBlocks = getDeviceId(extraPointers[2]); - int numThreads = getDeviceId(extraPointers[3]); - int shmem = getDeviceId(extraPointers[4]); + dim3 launchDims = dim3(numBlocks, numThreads, shmem); - dim3 launchDims = dim3(numBlocks, numThreads, shmem); - - BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction, ::aggregateKernelGeneric(launchDims, stream, opNum, arguments, numArguments, shapes, numShapes, indexArguments, numIndexArguments, intArrays, numIntArrays, realArguments, numRealArguments), FLOAT_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) failed"); + BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction, + ::aggregateKernelGeneric(launchDims, stream, opNum, arguments, numArguments, shapes, + numShapes, indexArguments, numIndexArguments, intArrays, + numIntArrays, realArguments, numRealArguments), FLOAT_TYPES); + nd4j::DebugHelper::checkErrorCode(stream, "execAggregateFloat(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void batchExecutor(Nd4jPointer *extraPointers, @@ -1897,17 +1788,25 @@ void execAggregateBatch(Nd4jPointer *extraPointers, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals, void *ptrToArguments, nd4j::DataType dtype) { - // not implemented yet - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - int numBlocks = getDeviceId(extraPointers[2]); - int numThreads = getDeviceId(extraPointers[3]); - int shmem = getDeviceId(extraPointers[4]); + try { + // not implemented yet + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + int numBlocks = getDeviceId(extraPointers[2]); + int numThreads = getDeviceId(extraPointers[3]); + int shmem = getDeviceId(extraPointers[4]); - dim3 launchDims = dim3(numAggregates, numThreads, shmem); + dim3 launchDims = dim3(numAggregates, numThreads, shmem); - BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction, ::aggregateBatchKernelGeneric(launchDims, stream, opNum, numAggregates, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, ptrToArguments), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(dtype, functions::aggregate::AggregatedFunction, + ::aggregateBatchKernelGeneric(launchDims, stream, opNum, numAggregates, maxArgs, + maxShapes, maxIntArrays, maxIntArraySize, maxIdx, maxReals, + ptrToArguments), FLOAT_TYPES); - DEBUG_KERNEL(stream, opNum); + DEBUG_KERNEL(stream, opNum); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1917,9 +1816,13 @@ void execRandom(Nd4jPointer *extraPointers, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1929,9 +1832,14 @@ void execRandom2(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, + dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + 
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -1943,9 +1851,14 @@ void execRandom3(Nd4jPointer *extraPointers, int opNum, Nd4jPointer stateHost, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraArguments) { - - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, + dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -2053,13 +1966,19 @@ void tear(Nd4jPointer *extras, Nd4jLong *zShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); + dim3 launchDims(512, 512, 512); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, + (launchDims, stream, dX, dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), + LIBND4J_TYPES); - cudaStream_t *stream = reinterpret_cast(extras[1]); - dim3 launchDims(512, 512, 512); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, tearKernelGeneric, (launchDims, stream, dX, dXShapeInfo, targets, zShapeInfo, tadShapeInfo, tadOffsets), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "tearFloat(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -2146,56 +2065,72 @@ void prescanArrayRecursive(Nd4jPointer *extras, int *dZ, int *dX, int numElement void encodeThresholdP1(Nd4jPointer *extras, void *dx, Nd4jLong *hXShapeInfo, Nd4jLong N, int *dz, float threshold) { + try { + cudaStream_t *stream = reinterpret_cast(extras[1]); - cudaStream_t *stream = reinterpret_cast(extras[1]); + int blockSize = 1024; + int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - int blockSize = 1024; - int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); + dim3 launchDims(numBlocks, blockSize, 1024); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, stream, dx, N, dz, threshold), LIBND4J_TYPES); - dim3 launchDims(numBlocks, blockSize, 1024); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP1Generic, (launchDims, stream, dx, N, dz, threshold), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP1Float(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP1Float(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void encodeThresholdP2Int(Nd4jPointer *extraPointers, int *dx, Nd4jLong N, int *dz) { - - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - //encoderKernelP2Float<<>>(dx, N, dz); - prescanArrayRecursive(extraPointers, dz, dx + 1, (int) N, 0); - nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + //encoderKernelP2Float<<>>(dx, N, dz); + prescanArrayRecursive(extraPointers, dz, dx + 1, (int) N, 0); + nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP2Int(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void encodeThresholdP3(Nd4jPointer *extraPointers, void *dx, Nd4jLong *hXShapeInfo, int *offsets, Nd4jLong N, int *dz){ + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + int blockSize = 1024; + int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - int blockSize = 1024; - int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); + dim3 launchDims(numBlocks, blockSize, 4096); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), LIBND4J_TYPES); - dim3 launchDims(numBlocks, blockSize, 4096); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, encoderKernelP3Generic, (launchDims, stream, dx, offsets, N, dz), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "encodeThresholdP3Float(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void decodeThreshold(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo){ + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + // we probably want to have smaller blocks here, memory writes are misaligned anyway + int blockSize = 128; + int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); - // we probably want to have smaller blocks here, memory writes are misaligned anyway - int blockSize = 128; - int numBlocks = N / blockSize + (N % blockSize ? 1 : 0); + dim3 launchDims(numBlocks, blockSize, 1024); + auto zType = nd4j::ArrayOptions::dataType(zShapeInfo); + BUILD_SINGLE_SELECTOR(zType, decoderKernelGeneric, (launchDims, stream, dx, N, dz), LIBND4J_TYPES); - dim3 launchDims(numBlocks, blockSize, 1024); - auto zType = nd4j::ArrayOptions::dataType(zShapeInfo); - BUILD_SINGLE_SELECTOR(zType, decoderKernelGeneric, (launchDims, stream, dx, N, dz), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "decodeThresholdFloat(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "decodeThresholdFloat(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } //////////////////////////////////////////////////////////////////////// @@ -2212,11 +2147,18 @@ void execReduce3All(Nd4jPointer *extraPointers, void *dDimension, Nd4jLong *dDimensionShape, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { - auto dimension = reinterpret_cast(dDimension); - int dimensionLength = static_cast(shape::length(hDimensionShape)); + try { + auto dimension = reinterpret_cast(dDimension); + int dimensionLength = static_cast(shape::length(hDimensionShape)); - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execReduce3All(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + NativeOpExecutioner::execReduce3All(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParamsVals, hY, + hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + dimensionLength, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -2224,57 +2166,65 @@ void sort(Nd4jPointer *extraPointers, void *x, Nd4jLong *xShapeInfo, void *dX, Nd4jLong *dXShapeInfo, bool descending) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + auto xLength = shape::length(xShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; + // check if xLength is a power of 2, and use bitonic sort, if that's the case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { + int numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) + numBlocks++; - dim3 launchDims(numBlocks, numThreads, 32768); + dim3 launchDims(numBlocks, numThreads, 32768); - for (int k = 2; k <= xLength; k = 2*k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_SINGLE_SELECTOR(xType, bitonicSortStepGeneric, (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES); - } + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_SINGLE_SELECTOR(xType, bitonicSortStepGeneric, + (launchDims, stream, dX, dXShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES); + } + } + } else { + int numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % 
numThreads > 0 || numBlocks == 0) + numBlocks++; + + numBlocks = nd4j::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, + (launchDims, stream, dX, dXShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); + } } - } else { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - numBlocks = nd4j::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window<<=1) { - int n = window; - int rev = 0; - do{ - int half = n >> 1; - BUILD_SINGLE_SELECTOR(xType, bitonicArbitraryStepGeneric, (launchDims, stream, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES); - n>>=1; - rev = 1; - } while(n > 1); - } + nd4j::DebugHelper::checkErrorCode(stream, "sort(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } - - nd4j::DebugHelper::checkErrorCode(stream, "sort(...) failed"); } @@ -2284,55 +2234,64 @@ void sortByKey(Nd4jPointer *extraPointers, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending) { + try { + auto stream = reinterpret_cast(extraPointers[1]); - auto stream = reinterpret_cast(extraPointers[1]); - - auto xLength = shape::length(xShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto xLength = shape::length(xShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; + // check if xLength is a power of 2, and use bitonic sort, if that's the case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { + int numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) + numBlocks++; - dim3 launchDims(numBlocks, numThreads, 32768); + dim3 launchDims(numBlocks, numThreads, 32768); - for (int k = 2; k <= xLength; k = 2*k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + } + } + } else { + int 
numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) + numBlocks++; + + numBlocks = nd4j::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); } } - } else { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - numBlocks = nd4j::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window<<=1) { - int n = window; - int rev = 0; - do{ - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); - n>>=1; - rev = 1; - } while(n > 1); - } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } } @@ -2342,54 +2301,63 @@ void sortByValue(Nd4jPointer *extraPointers, void *y, Nd4jLong *yShapeInfo, void *dy, Nd4jLong *dyShapeInfo, bool descending) { - auto stream = reinterpret_cast(extraPointers[1]); + try { + auto stream = reinterpret_cast(extraPointers[1]); - auto xLength = shape::length(xShapeInfo); - auto xEWS = shape::elementWiseStride(xShapeInfo); - auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); + auto xLength = shape::length(xShapeInfo); + auto xEWS = shape::elementWiseStride(xShapeInfo); + auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); - // check if xLength is a power of 2, and use bitonic sort, if that's the case - if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; + // check if xLength is a power of 2, and use bitonic sort, if that's the case + if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { + int numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) + numBlocks++; - dim3 launchDims(numBlocks, numThreads, 32768); + dim3 launchDims(numBlocks, numThreads, 32768); - for (int k = 2; k <= xLength; k = 2*k) { - for (int j = k >> 1; j > 0; j = j >> 1) { - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES); + for (int k = 2; k <= xLength; k = 2 * k) { + for (int j = k >> 1; j > 0; j = j >> 1) { + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, j, k, xLength, descending), + 
LIBND4J_TYPES, LIBND4J_TYPES); + } + } + } else { + int numThreads = nd4j::math::nd4j_min(512, xLength); + int numBlocks = xLength / numThreads; + if (xLength % numThreads > 0 || numBlocks == 0) + numBlocks++; + + numBlocks = nd4j::math::nd4j_min(512, numBlocks); + dim3 launchDims(numBlocks, numThreads, 32768); + + int max = 2, dg = 0; + while (max < xLength) { + max <<= 1; + dg++; + } + max <<= 1; + + for (int window = 2; window < max; window <<= 1) { + int n = window; + int rev = 0; + do { + int half = n >> 1; + BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), + LIBND4J_TYPES, LIBND4J_TYPES); + n >>= 1; + rev = 1; + } while (n > 1); } } - } else { - int numThreads = nd4j::math::nd4j_min(512, xLength); - int numBlocks = xLength / numThreads; - if (xLength % numThreads > 0 || numBlocks == 0) - numBlocks++; - - numBlocks = nd4j::math::nd4j_min(512, numBlocks); - dim3 launchDims(numBlocks, numThreads, 32768); - - int max = 2, dg = 0; - while (max < xLength) { - max <<= 1; - dg++; - } - max <<= 1; - - for (int window = 2; window < max; window<<=1) { - int n = window; - int rev = 0; - do{ - int half = n >> 1; - BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES); - n>>=1; - rev = 1; - } while(n > 1); - } + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } } @@ -2403,15 +2371,23 @@ void sortTadByKey(Nd4jPointer *extraPointers, int *dimension, int dimensionLength, bool descending) { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast(extraPointers[0]); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES); + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); + BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, + (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortTadByValue(Nd4jPointer *extraPointers, @@ -2422,16 +2398,24 @@ void sortTadByValue(Nd4jPointer *extraPointers, int *dimension, int dimensionLength, bool descending) { - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast(extraPointers[0]); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); - auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); - auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); + try { + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048); + auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); - BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, + (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), + LIBND4J_TYPES, LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } @@ -2443,15 +2427,23 @@ void sortTad(Nd4jPointer *extraPointers, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending) { - // to be implemented - auto stream = reinterpret_cast(extraPointers[1]); - auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast(extraPointers[0]); - auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768); - auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); - BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); + try { + // to be implemented + auto stream = reinterpret_cast(extraPointers[1]); + auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext() + : reinterpret_cast(extraPointers[0]); + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768); + auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); + BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, + (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), + LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) 
failed"); + nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) { @@ -2464,21 +2456,29 @@ Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, Nd4jLong N, int *dz, float threshold) { + try { - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - int *resultPointer = reinterpret_cast(extraPointers[2]); - int *reductionPointer = reinterpret_cast(extraPointers[3]); + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + int *resultPointer = reinterpret_cast(extraPointers[2]); + int *reductionPointer = reinterpret_cast(extraPointers[3]); - dim3 launchDims(512, 512, 32768); - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); - BUILD_SINGLE_SELECTOR(xType, cudaEncodeBitmapGeneric, (launchDims, stream, dx, N, dz, resultPointer, reductionPointer, threshold), LIBND4J_TYPES); + dim3 launchDims(512, 512, 32768); + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + BUILD_SINGLE_SELECTOR(xType, cudaEncodeBitmapGeneric, + (launchDims, stream, dx, N, dz, resultPointer, reductionPointer, threshold), + LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "encodeBitmapFloat(...) failed"); - Nd4jLong dZ = (Nd4jLong) resultPointer[0]; - resultPointer[0] = 0; + Nd4jLong dZ = (Nd4jLong) resultPointer[0]; + resultPointer[0] = 0; - return dZ; + return dZ; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 0; + } } @@ -2486,13 +2486,17 @@ void decodeBitmap(Nd4jPointer *extraPointers, void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo) { + try { + cudaStream_t *stream = reinterpret_cast(extraPointers[1]); + dim3 launchDims(512, 512, 16384); + auto xType = nd4j::ArrayOptions::dataType(zShapeInfo); + BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, (launchDims, stream, dx, N, dz), LIBND4J_TYPES); - cudaStream_t *stream = reinterpret_cast(extraPointers[1]); - dim3 launchDims(512, 512, 16384); - auto xType = nd4j::ArrayOptions::dataType(zShapeInfo); - BUILD_SINGLE_SELECTOR(xType, cudaDecodeBitmapGeneric, (launchDims, stream, dx, N, dz), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) failed"); + nd4j::DebugHelper::checkErrorCode(stream, "decodeBitmapFloat(...) 
failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) { @@ -2505,7 +2509,13 @@ void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length) { nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer) { - return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + try { + return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) { @@ -2560,9 +2570,16 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D } nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs); + return _calculateOutputShapes(extraPointers, op, inputBuffers, inputShapes, numInputShapes, tArgs, numTArgs, + iArgs, numIArgs, bArgs, numBArgs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::DeclarableOp* op, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { @@ -2584,9 +2601,15 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D } nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); + return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getShapeListSize(nd4j::ShapeList* list) { @@ -2681,39 +2704,59 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { - auto op = 
nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + return realExec(op, extraPointers, hash, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, + numOutputs, tArgs, numTArgs, iArgs, numIArgs, bArgs, numBArgs, isInplace); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext) { - auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); - auto context = reinterpret_cast(opContext); + try { + auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation(hash); + auto context = reinterpret_cast(opContext); - auto result = op->execute(context); + auto result = op->execute(context); - auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream()); - if (res != 0) - throw nd4j::cuda_exception::build("customOp execution failed", res); + auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream()); + if (res != 0) + throw nd4j::cuda_exception::build("customOp execution failed", res); - for (auto v:context->fastpath_in()) { - v->syncToDevice(); + for (auto v:context->fastpath_in()) { + if (!v->isEmpty()) + v->syncToDevice(); + } + + for (auto v:context->fastpath_out()) { + if (!v->isEmpty()) + v->syncToDevice(); + } + + return result; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; } - - for (auto v:context->fastpath_out()) { - v->syncToDevice(); - } - - return result; } int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer) { + try { + auto graph = nd4j::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); - auto graph = nd4j::graph::GraphExecutioner::importFromFlatPointer(flatBufferPointer); + nd4j::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); - nd4j::graph::GraphHolder::getInstance()->registerGraph(graphId, graph); - - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } @@ -2764,7 +2807,13 @@ static VariablesSet* executeStoredGraphT(Nd4jPointer *extraPointers, Nd4jLong gr } VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs) { - return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); + try { + return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getVariablesSetSize(nd4j::graph::VariablesSet* set) { @@ -2800,10 +2849,15 @@ void* 
getVariableBuffer(nd4j::graph::Variable* variable) { } int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { + try { + nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); - nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); - - return ND4J_STATUS_OK; + return ND4J_STATUS_OK; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } void deletePointerArray(Nd4jPointer pointer) { @@ -2918,8 +2972,15 @@ Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, nd4j::graph::GraphS Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPointer state, Nd4jLong opHash, Nd4jLong *scopes, int numScopes, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int numInputs, Nd4jPointer *outputBuffers, Nd4jPointer *outputShapes, int numOutputs) { - - return execCustomOpWithScope(extraPointers, reinterpret_cast(state), opHash, scopes, numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, numOutputs); + try { + return execCustomOpWithScope(extraPointers, reinterpret_cast(state), opHash, scopes, + numScopes, inputBuffers, inputShapes, numInputs, outputBuffers, outputShapes, + numOutputs); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return 1; + } } void deleteResultWrapper(Nd4jPointer ptr) { @@ -2937,181 +2998,186 @@ int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer dX, Nd4jLong *dXSh * void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, long N, int dstType, Nd4jPointer dZ); */ void convertTypes(Nd4jPointer *extras, int srcType, Nd4jPointer dX, Nd4jLong N, int dstType, Nd4jPointer dZ) { - auto dx = reinterpret_cast(dX); - auto dz = reinterpret_cast(dZ); + try { + auto dx = reinterpret_cast(dX); + auto dz = reinterpret_cast(dZ); - if (srcType == ND4J_FLOAT8) { - if (dstType == ND4J_FLOAT8) { - // convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { + if (srcType == ND4J_FLOAT8) { + if (dstType == ND4J_FLOAT8) { + // convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { - } else if (dstType == ND4J_FLOAT32) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types 
conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT8) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - //convertKernel(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: eventually we might want to add it - } else if (dstType == ND4J_FLOAT32) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_UINT8) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: still might want to add - } else if (dstType == ND4J_FLOAT32) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT16) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO: .... 
^^^ - } else if (dstType == ND4J_FLOAT32) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //nd4j::convertToThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_INT16) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { - // TODO... - } else if (dstType == ND4J_FLOAT32) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else { - printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_FLOAT24) { + } else if (dstType == ND4J_FLOAT32) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_INT8) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + //convertKernel(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: eventually we might want to add it + } else if (dstType == ND4J_FLOAT32) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_UINT8) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: still might want to add + } else if (dstType == ND4J_FLOAT32) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: 
[%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_FLOAT16) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO: .... ^^^ + } else if (dstType == ND4J_FLOAT32) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + //nd4j::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_INT16) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { + // TODO... + } else if (dstType == ND4J_FLOAT32) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else { + printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_FLOAT24) { - } else if (srcType == ND4J_FLOAT32) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { + } else if (srcType == ND4J_FLOAT32) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { - } else if (dstType == ND4J_DOUBLE) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_THRESHOLD) { - //nd4j::convertToThreshold(nullptr, dx, N, 
dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else if (srcType == ND4J_DOUBLE) { - if (dstType == ND4J_FLOAT8) { - //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT8) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_INT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_UINT16) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_FLOAT24) { + } else if (dstType == ND4J_DOUBLE) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_THRESHOLD) { + //nd4j::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_DOUBLE) { + if (dstType == ND4J_FLOAT8) { + //nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT8) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_INT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_UINT16) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_FLOAT24) { - } else if (dstType == ND4J_FLOAT32) { - nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - // - } else if (dstType == ND4J_THRESHOLD) { - //nd4j::convertToThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_FLOAT32) { + nd4j::TypeCast::convertGenericCuda(extras, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + // + } else if (dstType == ND4J_THRESHOLD) { + //nd4j::convertToThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } + } else if (srcType == ND4J_THRESHOLD) { + if (dstType == ND4J_FLOAT16) { + //nd4j::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_FLOAT32) { + //nd4j::convertFromThreshold(nullptr, dx, N, dz); + } else if (dstType == ND4J_DOUBLE) { + //nd4j::convertFromThreshold(nullptr, dx, N, dz); + } else { + nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } } else { nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); } - } else if (srcType == ND4J_THRESHOLD) { - if (dstType == ND4J_FLOAT16) { - //nd4j::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_FLOAT32) { - //nd4j::convertFromThreshold(nullptr, dx, N, dz); - } else if (dstType == ND4J_DOUBLE) { - //nd4j::convertFromThreshold(nullptr, dx, N, dz); - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); - } - } else { - nd4j_printf("Unsupported types conversion: [%i] -> [%i]\n", srcType, dstType); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); } } @@ -3209,20 +3275,31 @@ void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* hY, Nd4jLong* 
hYShapeInfo, Nd4jLong* hYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, int* hIindexes, int* dIndexes) { + try { + auto stream = reinterpret_cast(extraPointers[1]); - auto stream = reinterpret_cast(extraPointers[1]); + nd4j::DataType type = ArrayOptions::dataType(hXShapeInfo); - nd4j::DataType type = ArrayOptions::dataType(hXShapeInfo); - - BUILD_SINGLE_SELECTOR(type, scatterUpdateCudaLauncher, (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIndexes), LIBND4J_TYPES); - nd4j::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); + BUILD_SINGLE_SELECTOR(type, scatterUpdateCudaLauncher, + (stream, opCode, numOfSubArrs, dX, dXShapeInfo, dXOffsets, dY, dYShapeInfo, dYOffsets, dIndexes), + LIBND4J_TYPES); + nd4j::DebugHelper::checkErrorCode(stream, "scatterUpdate(...) failed"); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo) { - LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - auto p = reinterpret_cast(debugInfo); - NDArray array(buffer, specialBuffer, shapeInfo, &lc); - nd4j::DebugHelper::retrieveDebugStatistics(p, &array); + try { + LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); + auto p = reinterpret_cast(debugInfo); + NDArray array(buffer, specialBuffer, shapeInfo, &lc); + nd4j::DebugHelper::retrieveDebugStatistics(p, &array); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } void __global__ tryPointerKernel(void* p, int len) { @@ -3239,26 +3316,37 @@ void __global__ tryPointerKernel(void* p, int len) { } void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) { + try { + cudaStream_t stream; + cudaStreamCreate(&stream); - cudaStream_t stream; - cudaStreamCreate(&stream); + tryPointerKernel << < 256, 512, len + 64, stream >> > (p, len); + auto e = cudaStreamSynchronize(stream); - tryPointerKernel<<<256, 512, len+64, stream>>>(p, len); - auto e = cudaStreamSynchronize(stream); + if (e != 0) + throw nd4j::cuda_exception::build("tryPointer failed", e); - if (e != 0) - throw nd4j::cuda_exception::build("tryPointer failed", e); - - cudaStreamDestroy(stream); + cudaStreamDestroy(stream); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + } } int dataTypeFromNpyHeader(void *header) { return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); } nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty) { - auto buffer = new ConstantDataBuffer(); - *buffer = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); - return buffer; + try { + auto buffer = new ConstantDataBuffer(); + *buffer = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo( + ShapeDescriptor(dtype, order, shape, strides, rank, ews, empty)); + return buffer; + } catch (std::exception &e) { + 
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) { @@ -3359,60 +3447,79 @@ void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) { Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - std::vector shape(shapeSize); - bool _empty = false; - for(unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; + try { + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for (unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; - if (arr.shape[i] == 0) - _empty = true; + if (arr.shape[i] == 0) + _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (shape.size() == 1 && shape[0] == 0) { + // scalar case + shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); + } else if (_empty) { + if (shapeSize > 0) + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + } + return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, + true)); + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; } - - auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); - - Nd4jLong *shapeBuffer; - if (shape.size() == 1 && shape[0] == 0) { - // scalar case - shapeBuffer = nd4j::ShapeBuilders::createScalarShapeInfo(dtype); - } else if (_empty) { - if (shapeSize > 0) - shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); - else - shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); - } else { - shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 
'f' : 'c', shape); - } - return reinterpret_cast(nd4j::ConstantShapeHelper::getInstance()->createFromExisting(shapeBuffer, true)); } const char* runLightBenchmarkSuit(bool printOut) { - nd4j::LightBenchmarkSuit suit; - auto result = suit.runSuit(); + try { + nd4j::LightBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) + nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length()+1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char) 0x0; - return chars; + return chars; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } const char* runFullBenchmarkSuit(bool printOut) { - nd4j::FullBenchmarkSuit suit; - auto result = suit.runSuit(); + try { + nd4j::FullBenchmarkSuit suit; + auto result = suit.runSuit(); - if (printOut) - nd4j_printf("%s\n", result.data()); + if (printOut) + nd4j_printf("%s\n", result.data()); - auto chars = new char[result.length()+1]; - std::memcpy(chars, result.data(), result.length()); - chars[result.length()] = (char) 0x0; + auto chars = new char[result.length() + 1]; + std::memcpy(chars, result.data(), result.length()); + chars[result.length()] = (char) 0x0; - return chars; + return chars; + } catch (std::exception &e) { + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); + nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); + return nullptr; + } } Nd4jLong getCachedMemory(int deviceId) { @@ -3449,4 +3556,12 @@ Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc) { Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc) { return lc->getCusolverHandle(); +} + +int lastErrorCode() { + return nd4j::LaunchContext::defaultContext()->errorReference()->errorCode(); +} + +const char* lastErrorMessage() { + return nd4j::LaunchContext::defaultContext()->errorReference()->errorMessage(); } \ No newline at end of file diff --git a/libnd4j/include/execution/ContextBuffers.h b/libnd4j/include/execution/ContextBuffers.h index 130354070..67c428d27 100644 --- a/libnd4j/include/execution/ContextBuffers.h +++ b/libnd4j/include/execution/ContextBuffers.h @@ -23,6 +23,7 @@ #include #include +#include namespace nd4j { class ND4J_EXPORT ContextBuffers { @@ -32,6 +33,7 @@ namespace nd4j { void* _allocationPointer = nullptr; void* _execStream = nullptr; void* _specialStream = nullptr; + sd::ErrorReference _errorReference; bool _allocated = false; bool _initialized = false; @@ -60,6 +62,8 @@ namespace nd4j { void setScalarBuffer(void* pointer); void setAllocationBuffer(void* pointer); + sd::ErrorReference* errorReference(); + void triggerOwnership(bool isOwner); int deviceId(); diff --git a/libnd4j/include/helpers/ProviderRNG.h b/libnd4j/include/execution/ErrorReference.h similarity index 57% rename from libnd4j/include/helpers/ProviderRNG.h rename to libnd4j/include/execution/ErrorReference.h index e82f6ac98..2b68d5855 100644 --- a/libnd4j/include/helpers/ProviderRNG.h +++ b/libnd4j/include/execution/ErrorReference.h @@ -15,32 +15,32 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 27.01.2018 +// @author raver119@gmail.com // -#ifndef LIBND4J_PROVIDERRNG_H 
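The two exports added above, lastErrorCode() and lastErrorMessage(), let a host caller poll the error state recorded by the new try/catch wrappers. A minimal caller-side sketch, assuming only those two declarations from this patch; the helper name and the printf reporting are illustrative, not part of the library:

    // Sketch: polling the error state after a wrapped native call.
    // Only lastErrorCode()/lastErrorMessage() come from the patch; the rest is illustrative.
    #include <cstdio>

    int lastErrorCode();             // added by this patch
    const char* lastErrorMessage();  // added by this patch

    void checkNativeError(const char* opName) {
        // ErrorReference::errorMessage() clears the stored code when the message is read,
        // so fetch the code before the message.
        int code = lastErrorCode();
        if (code != 0)
            std::printf("%s failed (code %i): %s\n", opName, code, lastErrorMessage());
    }
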
-#define LIBND4J_PROVIDERRNG_H +#ifndef DEV_TESTS_ERRORREFERENCE_H +#define DEV_TESTS_ERRORREFERENCE_H -#include -#include - -namespace nd4j { - -class ProviderRNG { - - protected: - random::RandomBuffer* _rng; - static std::mutex _mutex; - ProviderRNG(); +#include +#include +namespace sd { + class ND4J_EXPORT ErrorReference { + private: + int _errorCode = 0; + std::string _errorMessage; public: - ProviderRNG(const ProviderRNG&) = delete; - void operator=(const ProviderRNG&) = delete; - random::RandomBuffer* getRNG() const; - static ProviderRNG& getInstance(); -}; + ErrorReference() = default; + ~ErrorReference() = default; + int errorCode(); + const char* errorMessage(); + void setErrorCode(int errorCode); + void setErrorMessage(std::string message); + void setErrorMessage(const char* message); + }; } -#endif //LIBND4J_PROVIDERRNG_H + +#endif //DEV_TESTS_ERRORREFERENCE_H diff --git a/libnd4j/include/execution/LaunchContext.h b/libnd4j/include/execution/LaunchContext.h index 23165fa0e..5fae2162c 100644 --- a/libnd4j/include/execution/LaunchContext.h +++ b/libnd4j/include/execution/LaunchContext.h @@ -37,6 +37,7 @@ #include #include #include +#include @@ -97,9 +98,12 @@ class ND4J_EXPORT LaunchContext { int getDeviceID() const {return _deviceID;} void setDeviceID(int deviceID) { _deviceID = deviceID; } + sd::ErrorReference* errorReference(); static bool isInitialized(); static void releaseBuffers(); + + static LaunchContext* defaultContext(); diff --git a/libnd4j/include/execution/cpu/ContextBuffers.cpp b/libnd4j/include/execution/cpu/ContextBuffers.cpp index 3bf0a01eb..0038990c2 100644 --- a/libnd4j/include/execution/cpu/ContextBuffers.cpp +++ b/libnd4j/include/execution/cpu/ContextBuffers.cpp @@ -99,4 +99,8 @@ namespace nd4j { ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { return *this; } + + sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; + } } \ No newline at end of file diff --git a/libnd4j/include/execution/cpu/LaunchContext.cpp b/libnd4j/include/execution/cpu/LaunchContext.cpp index 3ee460350..60e29c7ca 100644 --- a/libnd4j/include/execution/cpu/LaunchContext.cpp +++ b/libnd4j/include/execution/cpu/LaunchContext.cpp @@ -23,7 +23,11 @@ #include #include +#ifdef IOS_BUILD nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); +#else +thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); +#endif namespace nd4j { @@ -65,4 +69,8 @@ namespace nd4j { void LaunchContext::releaseBuffers() { // } + + sd::ErrorReference* LaunchContext::errorReference() { + return contextBuffers.errorReference(); + } } \ No newline at end of file diff --git a/libnd4j/include/execution/cuda/ContextBuffers.cu b/libnd4j/include/execution/cuda/ContextBuffers.cu index 84db0c284..895bb6623 100644 --- a/libnd4j/include/execution/cuda/ContextBuffers.cu +++ b/libnd4j/include/execution/cuda/ContextBuffers.cu @@ -220,5 +220,9 @@ namespace nd4j { bool ContextBuffers::isInitialized() { return _initialized; } + + sd::ErrorReference* ContextBuffers::errorReference() { + return &_errorReference; + } } diff --git a/libnd4j/include/execution/cuda/LaunchContext.cu b/libnd4j/include/execution/cuda/LaunchContext.cu index 1292f756c..9d9f2c506 100644 --- a/libnd4j/include/execution/cuda/LaunchContext.cu +++ b/libnd4j/include/execution/cuda/LaunchContext.cu @@ -168,4 +168,8 @@ LaunchContext::LaunchContext() { bool LaunchContext::isInitialized() { return contextBuffers.isInitialized(); } + + sd::ErrorReference* LaunchContext::errorReference() { + return 
contextBuffers.errorReference(); + } } \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java b/libnd4j/include/execution/impl/ErrorReference.cpp similarity index 54% rename from deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java rename to libnd4j/include/execution/impl/ErrorReference.cpp index d1b87c443..7b3409aa1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/misc/FunctionUtil.java +++ b/libnd4j/include/execution/impl/ErrorReference.cpp @@ -14,32 +14,33 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.ui.play.misc; +// +// @author raver119@gmail.com +// -import play.libs.F; -import play.mvc.Result; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.function.Supplier; +#include -/** - * Utility methods for Routing - * - * @author Alex Black - */ -public class FunctionUtil { - - public static F.Function0 function0(Supplier supplier) { - return supplier::get; +namespace sd { + int ErrorReference::errorCode() { + return _errorCode; } - public static F.Function function(Function function) { - return function::apply; + const char* ErrorReference::errorMessage() { + // since we're fetching error message - error code will be assumed consumed & nullified + _errorCode = 0; + return _errorMessage.c_str(); } - public static F.Function2 biFunction(BiFunction function) { - return function::apply; + void ErrorReference::setErrorCode(int errorCode) { + _errorCode = errorCode; } + void ErrorReference::setErrorMessage(std::string message) { + _errorMessage = message; + } + + void ErrorReference::setErrorMessage(const char* message) { + _errorMessage = std::string(message); + } } diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index bda04414f..d04d3315d 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -80,14 +80,14 @@ namespace nd4j { }; - template + template class ND4J_EXPORT IndexReductionLoops { private: public: - static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); + static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); template - static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams); + static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams); }; diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp index 33e230bd5..0a096b65f 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp @@ -24,10 +24,10 @@ using namespace simdOps; ////////////////////////////////////////////////////////////////////////////// -template +template template -void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, - Nd4jLong* z, Nd4jLong* zShapeInfo, +void 
nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, + Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) { @@ -62,7 +62,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -80,7 +80,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i * zEws] = indexValue.index; + z[i * zEws] = (Z) indexValue.index; } } break; @@ -98,7 +98,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -122,7 +122,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -148,7 +148,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -176,7 +176,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -206,7 +206,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -227,7 +227,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ); - z[zOffset] = indexValue.index; + z[zOffset] = (Z) indexValue.index; } } break; @@ -248,7 +248,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i * zEws] = indexValue.index; + z[i * zEws] = (Z) indexValue.index; } } break; @@ -272,18 +272,19 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ); - z[zOffset] = indexValue.index; + z[zOffset] = (Z) indexValue.index; } } } } -template -void nd4j::IndexReductionLoops::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) { +template +void nd4j::IndexReductionLoops::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* vz, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) { auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); auto extraParams = reinterpret_cast(vextraParams); - DISPATCH_BY_OPNUM_T(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); + DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); } -BUILD_SINGLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* 
tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, INDEXING_TYPES); \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index dda709545..32394f705 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -218,6 +218,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + const int deviceId = AffinityManager::currentDeviceId(); + const int major = Environment::getInstance()->capabilities()[deviceId].first(); + NDArray::prepareSpecialUse({pC}, {pA, pB}); // choose appropriate cuda gemm api depending on data types @@ -228,15 +231,15 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou float alphaF(alpha), betaF(beta); status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); } - else if(ABC && aType == DataType::HALF) { + else if(ABC && aType == DataType::HALF && major >= 6) { float16 alphaH(alpha), betaH(beta); status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); } - else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32) { + else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) { float alphaF(alpha), betaF(beta); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); } - else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32) { + else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) { float alphaF(alpha), betaF(beta); status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); } diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index 4b11cfc0d..fb5acf19b 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -71,12 +71,25 @@ inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) { case broadcast::And: return pairwise::And; case broadcast::Or: return pairwise::Or; case broadcast::Xor: return pairwise::Xor; - case broadcast::Not: return pairwise::Not; + case broadcast::Not: return pairwise::Not; default: throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation"); } } + inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { + switch (op) { + case broadcast::IntOps::IntAnd: return pairwise::IntOps::IntAnd; + case broadcast::IntOps::IntOr: return pairwise::IntOps::IntOr; + case broadcast::IntOps::IntXor: return pairwise::IntOps::IntXor; + case broadcast::IntOps::ShiftLeft: return pairwise::IntOps::ShiftLeft; + case broadcast::IntOps::ShiftRight: return pairwise::IntOps::ShiftRight; + case broadcast::IntOps::CyclicShiftLeft: return pairwise::IntOps::CyclicShiftLeft; + case broadcast::IntOps::CyclicShiftRight: return pairwise::IntOps::CyclicShiftRight; + 
default: + throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation"); + } + } } #endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file diff --git a/libnd4j/include/loops/broadcasting_int.h b/libnd4j/include/loops/broadcasting_int.h new file mode 100644 index 000000000..84bc0f949 --- /dev/null +++ b/libnd4j/include/loops/broadcasting_int.h @@ -0,0 +1,164 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * broadcasting.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef BROADCASTING_INT_H_ +#define BROADCASTING_INT_H_ +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#include +#endif +#ifdef __JNI__ +#include +#endif + +#include + +#include "legacy_ops.h" + +namespace functions { + namespace broadcast { + +/** + * Broadcast operation + * for broadcasting a smaller tensor + * along long a bigger one. + */ + template + class BroadcastInt { + public: + +#ifdef __CUDACC__ + + template + static __device__ void transformCuda( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformInverseCuda( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void 
*result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + +#endif + + static void exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + static void execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + template + static void execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + }; + } +} + +#endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/cpu/broadcasting_int.cpp b/libnd4j/include/loops/cpu/broadcasting_int.cpp new file mode 100644 index 000000000..c092da50b --- /dev/null +++ b/libnd4j/include/loops/cpu/broadcasting_int.cpp @@ -0,0 +1,464 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include + +using namespace simdOps; + +namespace functions { + namespace broadcast { + + template + void BroadcastInt::exec(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + dimension, + dimensionLength, + xTadShapeInfo, + xTadOffset, + zTadShapeInfo, + zTadOffset), BROADCAST_INT_OPS); + } + + template + void BroadcastInt::execInverse(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + dimension, + dimensionLength, + xTadShapeInfo, + xTadOffset, + zTadShapeInfo, + zTadOffset), BROADCAST_INT_OPS); + } + + template + template + void BroadcastInt::exec(void *vx, + Nd4jLong *xShapeInfo, + void *vy, + Nd4jLong *yShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. 
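+                // A TAD ("tensor along dimension") is the sub-array obtained by fixing every
+                // dimension except those listed in `dimension`. For example, broadcasting a
+                // length-4 vector y over a 2x3x4 array x along dimension {2} produces
+                // 2*3 = 6 TADs of length 4, and y is applied element-wise to each of them
+                // (below: tadLength = 4, tads = length(x) / tadLength = 6).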
+ //permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + //int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = nd4j::math::nd4j_max(1, tadsPerThread); + threads = nd4j::math::nd4j_min(threads, omp_get_max_threads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + } + } + else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + // TODO: cover this codebranch with tests + // all this stuff already happens within thread + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { + + uint 
tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + } + } + else { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + } + } + } + + + template + template + void BroadcastInt::execInverse(void *vx, + Nd4jLong *xShapeInfo, + void *vy, + Nd4jLong *yShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *yTadShapeInfo, + Nd4jLong *yTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. 
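+                // Inverse broadcast: here y is the array that is split into TADs and the whole
+                // of x acts as the broadcast operand, so each output TAD is op(x, yTad) rather
+                // than op(xTad, y) as in exec() above (tadLength and tads are computed from
+                // yShapeInfo below).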
+ //permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + //int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = nd4j::math::nd4j_max(1, tadsPerThread); + threads = nd4j::math::nd4j_min(threads, omp_get_max_threads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + } + } + else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + // TODO: cover this codebranch with tests + // all this stuff already happens within thread + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint 
xShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + } + } + else { + + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/indexreduce.cpp b/libnd4j/include/loops/cpu/indexreduce.cpp index 951ac287b..5a7beee24 100644 --- a/libnd4j/include/loops/cpu/indexreduce.cpp +++ b/libnd4j/include/loops/cpu/indexreduce.cpp @@ -31,26 +31,27 @@ namespace functions { namespace indexreduce { //////////////////////////////////////////////////////////////////////// -template Nd4jLong IndexReduce::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); +template +Nd4jLong IndexReduce::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// -template -void IndexReduce::exec(const int opNum, +template +void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, - Nd4jLong *z, Nd4jLong *zShapeInfo, 
+ void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { -DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// -template +template template -Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { +Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { auto x = reinterpret_cast(vx); auto extraParams = reinterpret_cast(vextraParams); @@ -105,15 +106,16 @@ Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextra //////////////////////////////////////////////////////////////////////// -template +template template -void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, +void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, - Nd4jLong *z, Nd4jLong *zShapeInfo, + void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); auto extraParams = reinterpret_cast(vextraParams); const Nd4jLong zLen = shape::length(zShapeInfo); @@ -124,12 +126,12 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, const auto indexValue = OpType::startingIndexValue(x); PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold()) for (uint i = 0; i < zLen; i++) - z[i] = indexValue.index;; + z[i] = (Z) indexValue.index;; return; } if(shape::isScalar(zShapeInfo)) { - z[0] = execScalar(x,xShapeInfo,extraParams); + z[0] = (Z) execScalar(x,xShapeInfo,extraParams); return; } @@ -146,11 +148,11 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, tadOffsets = tadPack.primaryOffsets(); } - nd4j::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); + nd4j::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise_int.cpp b/libnd4j/include/loops/cpu/pairwise_int.cpp new file mode 100644 index 000000000..b356adcc2 --- /dev/null +++ b/libnd4j/include/loops/cpu/pairwise_int.cpp @@ -0,0 +1,309 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +using namespace simdOps; + +namespace functions { + namespace pairwise_transforms { + + template + void PairWiseIntTransform::exec( + const int opNum, + void *x, + Nd4jLong xEws, + void *y, + Nd4jLong yEws, + void *z, + Nd4jLong zEws, + void *extraParams, + Nd4jLong n) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xEws, + y, + yEws, + z, + zEws, + extraParams, + n), PAIRWISE_INT_OPS); + }; + + + + template + template + void PairWiseIntTransform::exec(void *vx, + Nd4jLong xEws, + void *vy, + Nd4jLong yEws, + void *vz, + Nd4jLong zEws, + void *vextraParams, + const Nd4jLong n) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + nd4j::OmpLaunchHelper info(n); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + Nd4jLong threadOffset = info.getThreadOffset(threadNum); + auto xi = x + threadOffset; + auto yi = y + threadOffset; + auto zi = z + threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) + zi[i] = OpType::op(xi[i], yi[i], extraParams); + } + } + else { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + Nd4jLong threadOffset = info.getThreadOffset(threadNum); + auto xi = x + xEws*threadOffset; + auto yi = y + yEws*threadOffset; + auto zi = z + zEws*threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) + zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams); + } + } + } + + template + void PairWiseIntTransform::exec( + const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *extraParams) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + extraParams), + PAIRWISE_INT_OPS); + }; + + + template + template + void PairWiseIntTransform::exec(void *vx, Nd4jLong* xShapeInfo, + void *vy, Nd4jLong* yShapeInfo, + void *vz, Nd4jLong* zShapeInfo, + void *vextraParams) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + nd4j::OmpLaunchHelper info(n); + + if (shape::isScalar(yShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for(Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + } + } + } + else { + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = 
nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for(Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + } + } + } + return; + } + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n); + } + else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); + } + else { + + if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + 
} + } + } + else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + } + } + } + else { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); + } +} diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp new file mode 100644 index 000000000..9920cc836 --- /dev/null +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -0,0 +1,255 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../scalar_int.h" +#include +#include +#include + +#include "../legacy_ops.h" + +using namespace simdOps; + +namespace functions { + namespace scalar { + + + template + template + void ScalarIntTransform::transform(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, + void *vz, Nd4jLong *zShapeInfo, + void *vscalars, + int *dimension, int dimensionLength, + Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, + Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { + + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != nd4j::LoopKind::EWS1 && kindOfLoop != nd4j::LoopKind::EWSNONZERO) { + printf("ScalarIntTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); + return; + } + + int num_threads = nd4j::math::nd4j_min(numTads, omp_get_max_threads()); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) + for (unsigned int r = 0; r < numTads; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + } + } + else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO + PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) + for (unsigned int r = 0; r < numTads; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + } + } + } + + template + void ScalarIntTransform::transform(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *z, + Nd4jLong *zShapeInfo, + void *scalars, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffsets, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffsets) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS); + } + + + template + void ScalarIntTransform::transform(const int opNum, + void *x, + Nd4jLong xEws, + void *z, + Nd4jLong zEws, + void *scalar, + void *extraParams, + const Nd4jLong n) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS); + } + + template + void ScalarIntTransform::transform(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *scalar, + void *extraParams) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS); + } + + template + template + void ScalarIntTransform::transform(void *vx, + Nd4jLong *xShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + void *vscalar, + void *vextraParams) { + + auto x = 
reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + // nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws); + + nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + nd4j::OmpLaunchHelper info(len); + + if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + } + } + } + else { + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + } + } + } + } + + + template + template + void ScalarIntTransform::transform(void *vx, + Nd4jLong xEws, + void *vz, + Nd4jLong zEws, + void *vscalar, + void *vextraParams, + const Nd4jLong len) { + + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + nd4j::OmpLaunchHelper info(len); + + if (xEws == 1 && zEws == 1) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto xi = x + threadOffset; + auto zi = z + threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) + zi[i] = OpType::op(xi[i], scalar, extraParams); + } + } + else { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto xi = x + xEws * threadOffset; + auto zi = z + zEws * threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) + zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams); + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + +} +} diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu new file mode 100644 index 000000000..38193f35d --- /dev/null +++ 
b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -0,0 +1,291 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace simdOps; + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void broadcastIntSimple( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::broadcast::BroadcastInt::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +} + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void broadcastBoolInverseSimple( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::broadcast::BroadcastInt::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +} + +namespace functions { + namespace broadcast { +////////////////////////////////////////////////////////////////////////// + template + template + __host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + } + +////////////////////////////////////////////////////////////////////////// + template + __host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) + } + +////////////////////////////////////////////////////////////////////////// + 
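/*
 * Illustrative sketch (not part of this change): the kernels above only
 * dispatch into BroadcastInt::transformCuda / transformInverseCuda; the
 * per-element work comes from the BROADCAST_INT_OPS family (ShiftLeft,
 * ShiftRight, CyclicShiftLeft, CyclicShiftRight, IntAnd, IntOr, IntXor)
 * defined in the simdOps namespace elsewhere in libnd4j. Assuming the usual
 * rotate semantics (the real ops may treat edge cases differently), a cyclic
 * left shift on a single element could look like this, given <climits> and
 * <type_traits> at file scope:
 *
 *   template <typename T>
 *   static T cyclicShiftLeftSketch(T value, unsigned int shift) {
 *       constexpr unsigned int bits = sizeof(T) * CHAR_BIT;
 *       shift %= bits;                        // keep the rotation in range
 *       if (shift == 0)
 *           return value;
 *       using U = typename std::make_unsigned<T>::type;
 *       const U u = static_cast<U>(value);    // avoid UB on signed shifts
 *       return static_cast<T>((u << shift) | (u >> (bits - shift)));
 *   }
 *
 * IntAnd / IntOr / IntXor map onto the corresponding built-in bitwise
 * operators for the integer element type.
 */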
template + template + __host__ void BroadcastInt::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + } + +////////////////////////////////////////////////////////////////////////// + template + __host__ void BroadcastInt::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) + } + +////////////////////////////////////////////////////////////////////////// + template + template + __device__ void BroadcastInt::transformInverseCuda( + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. 
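// In this "inverse" variant y and z are the larger arrays that get split
// into TADs (numTads = shape::length(yShapeInfo) / tadLength below), while
// x is the short operand whose length matches a single TAD, so it is
// addressed directly through xShapeInfo / xEWS inside the per-TAD loops.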
+ //permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + + for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } + else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } + } + +////////////////////////////////////////////////////////////////////////// + template + template + __device__ void BroadcastInt::transformCuda( + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. 
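// Forward case: x and z are decomposed into TADs
// (numTads = shape::length(xShapeInfo) / tadLength below) and y is the
// short operand of length tadLength; each block processes TADs
// r = blockIdx.x, blockIdx.x + gridDim.x, ... and threads stride over the
// elements of each TAD.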
+ //permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ X *rZ; + __shared__ X *rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + + if (threadIdx.x == 0) { + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; + } + __syncthreads(); + + + if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + + for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } + else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } + } + } + } + + + template + void BroadcastInt::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + void BroadcastInt::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastInt::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastInt::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 18e5b1432..5f0cf07ae 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -29,37 +29,37 @@ using namespace simdOps; -template +template static __global__ void simpleIndexReduceGeneric(const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, - Nd4jLong *result, + void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - 
functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); + functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); } namespace functions { namespace indexreduce { - template - _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, + template + _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, - Nd4jLong *result, Nd4jLong *resultShapeInfo, + void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - simpleIndexReduceGeneric<<>>(opNum, + simpleIndexReduceGeneric<<>>(opNum, dx, xShapeInfo, xRank, extraParams, result, resultShapeInfo, 0, @@ -67,13 +67,11 @@ namespace functions { 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - - nd4j::DebugHelper::checkErrorCode(stream, "execIndexReduceScalar(...) failed"); } - template - _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - simpleIndexReduceGeneric<<>>( + template + _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { + simpleIndexReduceGeneric<<>>( opNum, dx, xShapeInfo, xRank, @@ -83,8 +81,6 @@ namespace functions { dimension, dimensionLength, 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - - DEBUG_KERNEL(stream, opNum); } // This is the un-specialized struct. Note that we prevent instantiation of this @@ -122,14 +118,14 @@ namespace functions { } }; - template + template template - __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { + __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { // start the shared memory loop on the next power of 2 less // than the block size. If block size is not a power of 2, // accumulate the intermediate sums in the remainder range. 
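// Example: with blockDim.x == 96, floorPow2 ends up as 64; threads 64..95
// first fold their partials into slots 0..31, and the binary tree reduction
// below then runs over the remaining 64 slots
// (activeThreads = 32, 16, 8, ..., 1).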
- auto extraParams = static_cast(vextraParams); - IndexValue *sPartials = *sPartialsRef; + auto extraParams = static_cast(vextraParams); + IndexValue *sPartials = *sPartialsRef; Nd4jLong floorPow2 = blockDim.x; if (floorPow2 & (floorPow2 - 1)) { @@ -138,8 +134,8 @@ namespace functions { } if (tid >= floorPow2) { - IndexValue prev = sPartials[tid - floorPow2]; - IndexValue curr = sPartials[tid]; + IndexValue prev = sPartials[tid - floorPow2]; + IndexValue curr = sPartials[tid]; sPartials[tid - floorPow2] = OpType::update(prev,curr,extraParams); } __syncthreads(); @@ -147,21 +143,21 @@ namespace functions { for (int activeThreads = floorPow2 >> 1;activeThreads; activeThreads >>= 1) { if (tid < activeThreads && tid + activeThreads < numElements) { - IndexValue curr = sPartials[tid]; - IndexValue next = sPartials[tid + activeThreads]; + IndexValue curr = sPartials[tid]; + IndexValue next = sPartials[tid + activeThreads]; sPartials[tid] = OpType::update(curr,next,extraParams); } __syncthreads(); } } - template - __device__ void IndexReduce::transform( + template + __device__ void IndexReduce::transform( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, - Nd4jLong *result, + void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, @@ -170,15 +166,15 @@ namespace functions { void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); + DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); } - template + template template - __device__ void IndexReduce::transform(void *vdx, Nd4jLong *xShapeInfo, + __device__ void IndexReduce::transform(void *vdx, Nd4jLong *xShapeInfo, void *vextraParams, - Nd4jLong *result, Nd4jLong *resultShapeInfo, + void *vresult, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *vreductionBuffer, @@ -186,18 +182,19 @@ namespace functions { /**int * Gpu information for the problem */ - auto dx = static_cast(vdx); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); + auto dx = reinterpret_cast(vdx); + auto result = reinterpret_cast(vresult); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); auto order = shape::order(xShapeInfo); int tid = blockIdx.x * blockDim.x + threadIdx.x; __shared__ volatile int resultScalar; //shared memory space for storing intermediate results - __shared__ IndexValue* sPartials; + __shared__ IndexValue* sPartials; if(threadIdx.x == 0) { extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); + sPartials = reinterpret_cast*>(shmem); } __syncthreads(); @@ -210,7 +207,7 @@ namespace functions { //only compute the tad indexes once - IndexValue reduction = OpType::startingIndexValue(dx); + IndexValue reduction = OpType::startingIndexValue(dx); if (threadIdx.x == 0) { if (resultShapeInfo != nullptr) @@ -255,7 +252,7 @@ namespace functions { for(int i = threadIdx.x;i < tadLength; i += blockDim.x) { auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - IndexValue comp {dx[xOffset], i}; + IndexValue comp 
{dx[xOffset], i}; sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); } @@ -264,7 +261,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[r] = sPartials[threadIdx.x].index; + result[r] = (Z) sPartials[threadIdx.x].index; } __syncthreads(); } @@ -276,7 +273,7 @@ namespace functions { sPartials[threadIdx.x] = OpType::startingIndexValue(dx); for (int x = threadIdx.x; x < tadLength; x+= blockDim.x) { - IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; + IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); } @@ -285,7 +282,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[i] = sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); + result[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); } __syncthreads(); } @@ -296,14 +293,14 @@ namespace functions { if(xElementWiseStride >= 1 && order == 'c') { for(Nd4jLong i = tid;i < n; i += (blockDim.x * gridDim.x)) { - IndexValue indexVal = {dx[i * xElementWiseStride], i}; + IndexValue indexVal = {dx[i * xElementWiseStride], i}; reduction = OpType::update(reduction, indexVal, extraParams); } } else { for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) { auto offset = shape::getIndexOffset(i, xShapeInfo, n); - IndexValue indexVal = {dx[offset], i}; + IndexValue indexVal = {dx[offset], i}; reduction = OpType::update(reduction, indexVal, extraParams); } } @@ -320,7 +317,7 @@ namespace functions { unsigned int *tc = (unsigned int *) reductionBuffer; tid = threadIdx.x; if (threadIdx.x == 0) { - auto pBuffer = reinterpret_cast *>(reductionBuffer); + auto pBuffer = reinterpret_cast *>(reductionBuffer); pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; } __threadfence(); @@ -335,7 +332,7 @@ namespace functions { if (amLast) { tc[16384] = 0; - IndexValue *pBuffer = (IndexValue *) reductionBuffer; + IndexValue *pBuffer = (IndexValue *) reductionBuffer; sPartials[threadIdx.x] = OpType::startingIndexValue(dx); @@ -348,14 +345,14 @@ namespace functions { __syncthreads(); if (tid == 0) { - result[0] = sPartials[0].index; + result[0] = (Z) sPartials[0].index; } } } else { if (tid == 0) { auto tc = reinterpret_cast(reductionBuffer); tc[16384] = 0; - result[0] = sPartials[0].index; + result[0] = (Z) sPartials[0].index; } } @@ -365,30 +362,30 @@ namespace functions { - template - Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + template + Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { return 0; } - template - void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + template + void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { } - template + template template - Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { + Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { return 0; } - template + template template - _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong 
*result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); } } diff --git a/libnd4j/include/loops/cuda/pairwise_int.cu b/libnd4j/include/loops/cuda/pairwise_int.cu new file mode 100644 index 000000000..5cc12846c --- /dev/null +++ b/libnd4j/include/loops/cuda/pairwise_int.cu @@ -0,0 +1,173 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018 + +#ifndef PAIRWISE_INT_CU +#define PAIRWISE_INT_CU + + +#include "../pairwise_int.h" + + +using namespace simdOps; + +//////////////////////////////////////////////////////////////////////////////// +template +__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void *vextraParams) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } + else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, len); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, len); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } + else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, len); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, len); + auto zOffset = shape::getIndexOffset(i, zShapeInfo, len); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], 
extraParams); + } + } +} + + +namespace functions { +namespace pairwise_transforms { + +//////////////////////////////////////////////////////////////////////////////// +template +template +void _CUDA_H PairWiseIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void *vextraParams){ + + pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +} + + +//////////////////////////////////////////////////////////////////////////////// +template +void PairWiseIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) { + auto xType = nd4j::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); +} + + + template + void PairWiseIntTransform::exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams) { + + } + + template + void PairWiseIntTransform::exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n) { + + } + + + template + template + void PairWiseIntTransform::exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams) { + + } + + template + template + void PairWiseIntTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n) { + + } + + + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); +} +} + +#endif // PAIRWISE_INT_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_int.cu b/libnd4j/include/loops/cuda/scalar_int.cu new file mode 100644 index 000000000..48f141525 --- /dev/null +++ b/libnd4j/include/loops/cuda/scalar_int.cu @@ -0,0 +1,269 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018 +// @author raver119@gmail.com +// + +#include "../scalar_int.h" +#include +#include + +#include "../legacy_ops.h" + +using namespace simdOps; + +//////////////////////////////////////////////////////////////////////// +template +__global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo, + void *extraParams, + void *z, Nd4jLong *zShapeInfo, + void *scalars, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::scalar::ScalarIntTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +} + + +//////////////////////////////////////////////////////////////////////// +template +__global__ void scalarSimpleShaped(void* x, void *y, Nd4jLong *xShapeInfo, void *params, void *z, Nd4jLong *zShapeInfo, int *allocationBuffer) { + + functions::scalar::ScalarIntTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +} + + + + + +// *********************************************************************// +// *********************************************************************// +namespace functions { +namespace scalar { + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(void* vscalar, + void *vy, Nd4jLong *yShapeInfo, + void *vparams, + void *vz, Nd4jLong *zShapeInfo, + int *allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if(threadIdx.x == 0) + len = shape::length(yShapeInfo); + __syncthreads(); + + if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); + } + else { + for (Nd4jLong i = tid; i < len; i+= totalThreads) + z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params); + } +} + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(Nd4jLong len, + void* vx, + void *vy, Nd4jLong yEWS, + void *vparams, + void *vz, Nd4jLong zEWS, + int *allocationBuffer) { + + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if(yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) + z[i] = OpType::op(y[i], x, params); + } + else { + for (; i < len; i += totalThreads) + z[i * zEWS] = 
OpType::op(y[i * yEWS], x, params); + } +} + + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, + void *vz, Nd4jLong *zShapeInfo, + void *vscalars, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + auto numTads =shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X *oZ = z + tadOffsetsZ[r]; + X *oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); + } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X *oZ = z + tadOffsetsZ[r]; + X *oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams); + } + } +} + + +//////////////////////////////////////////////////////////////////////// +template +template +_CUDA_H void ScalarIntTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, + void *x, Nd4jLong *xShapeInfo, + void *z, Nd4jLong *zShapeInfo, + void *scalars, + void *extraParams, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +} + +//////////////////////////////////////////////////////////////////////// +template +template +void _CUDA_H ScalarIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, + void *vx, Nd4jLong *xShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void* vscalar, + void *vextraParams, int *allocPointer){ + + scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); +} + +//////////////////////////////////////////////////////////////////////// +template +void ScalarIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, + int opNum, + void *vx, Nd4jLong *xShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void* vscalar, + void *vextraParams) { + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS); +} + +//////////////////////////////////////////////////////////////////////// +template +void ScalarIntTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int 
opNum, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS); +} + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarIntTransform::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarIntTransform::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + template + void ScalarIntTransform::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } +} +} + diff --git a/libnd4j/include/loops/indexreduce.h b/libnd4j/include/loops/indexreduce.h index 40f98c692..792ed16a9 100755 --- a/libnd4j/include/loops/indexreduce.h +++ b/libnd4j/include/loops/indexreduce.h @@ -52,35 +52,35 @@ namespace functions { namespace indexreduce { - template + template class IndexReduce { public: #ifdef __CUDACC__ - static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); template - static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams); + static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams); template - static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int 
postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); - static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); - static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); #endif static Nd4jLong execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams); - static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); template static _CUDA_H Nd4jLong execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams); template - static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); }; } } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b3096ac0e..b0d891287 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -29,6 +29,16 @@ (4, aggregateOps::CBOW) ,\ (5, aggregateOps::GEMM) +#define BROADCAST_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) + + #define BROADCAST_BOOL_OPS \ (0, EqualTo),\ (1, GreaterThan),\ @@ -171,6 +181,14 @@ (0, SummaryStatsVariance), \ (1, SummaryStatsStandardDeviation) +#define SCALAR_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) #define SCALAR_BOOL_OPS \ (0, EqualTo),\ @@ -300,6 +318,15 @@ (13, ExponentialDistribution),\ (14, ExponentialDistributionInv) +#define 
PAIRWISE_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) + #define PAIRWISE_BOOL_OPS \ (0, EqualTo),\ (1, GreaterThan),\ diff --git a/libnd4j/include/loops/pairwise_int.h b/libnd4j/include/loops/pairwise_int.h new file mode 100644 index 000000000..14d273285 --- /dev/null +++ b/libnd4j/include/loops/pairwise_int.h @@ -0,0 +1,119 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * pairwise_transform.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef PAIRWISE_INT_H_ +#define PAIRWISE_INT_H_ +#ifdef _OPENMP +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#include +#endif + +#ifndef _OPENMP +#define omp_get_thread_num() 0 +#define omp_get_max_threads() 1 +#endif + + +#include "legacy_ops.h" + +using namespace simdOps; + +namespace functions { + namespace pairwise_transforms { + +/** + * Transforms involving 2 arrays + */ + template + class PairWiseIntTransform { + public: + +#ifdef __CUDACC__ + + template + static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams); + + static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams); + + +#endif + public: + + static void exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams); + + static void exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n); + + + template + static void exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams); + + template + static void exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n); + }; + } +} + +#endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/scalar_int.h b/libnd4j/include/loops/scalar_int.h new file mode 100644 index 000000000..f873d5419 --- /dev/null +++ b/libnd4j/include/loops/scalar_int.h @@ -0,0 +1,142 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * scalar.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef SCALAR_INT_H_ +#define SCALAR_INT_H_ +#include + +#ifdef __JNI__ +#include +#endif +#include +#include +#include +#include "helpers/logger.h" +#include +#include + +#ifdef __CUDACC__ +#include +#include +#include +#endif + +#include "legacy_ops.h" + +namespace functions { + namespace scalar { +/** + * Apply a scalar + * operation to an array + */ + template + class ScalarIntTransform { + + public: + +#ifdef __CUDACC__ + + template + __device__ + static void transformCuda(void* scalar, void *vy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *resultShapeInfo, int *allocationBuffer); + + template + __device__ + static void transformCuda(Nd4jLong n, void* vx, void *vy, Nd4jLong yEWS, void *vparams, void *vz, Nd4jLong zEWS, int *allocationBuffer); + + template + __device__ + static void transformCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vscalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + __host__ + static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + __host__ + static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void* vscalar, void *vextraParams, int *allocPointer); + + __host__ + static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void* scalar, void *extraParams); + + __host__ + static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + +/* +#include "cuda/scalar_temp.cu" +*/ +#endif + template + static void transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, 
Nd4jLong *resultShapeInfo, void *scalar, void *extraParams); + + static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n); + + + + + /* + * ScalarOp along dimension + */ + + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams); + + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n); + }; + } +} + + +#endif /* SCALAR_H_ */ diff --git a/libnd4j/include/op_enums.h b/libnd4j/include/op_enums.h index bf9d4c72f..8a100153f 100644 --- a/libnd4j/include/op_enums.h +++ b/libnd4j/include/op_enums.h @@ -63,6 +63,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(PAIRWISE_INT_OPS) + }; } namespace scalar { @@ -73,6 +77,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(SCALAR_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(SCALAR_INT_OPS) + }; } namespace reduce { @@ -113,6 +121,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(BROADCAST_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(BROADCAST_INT_OPS) + }; } namespace variance { diff --git a/libnd4j/include/helpers/impl/ProviderRNG.cpp b/libnd4j/include/ops/BroadcastIntOpsTuple.h similarity index 50% rename from libnd4j/include/helpers/impl/ProviderRNG.cpp rename to libnd4j/include/ops/BroadcastIntOpsTuple.h index 216aa3a32..df40907a9 100644 --- a/libnd4j/include/helpers/impl/ProviderRNG.cpp +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -15,37 +15,35 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 27.01.2018 +// @author raver119@gmail.com // -#include -#include +#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H +#define DEV_TESTS_BROADCASTINTOPSTUPLE_H +#include namespace nd4j { - -ProviderRNG::ProviderRNG() { + class BroadcastIntOpsTuple { + private: - Nd4jLong *buffer = new Nd4jLong[100000]; - std::lock_guard lock(_mutex); - #ifndef __CUDABLAS__ - // at this moment we don't have streams etc, so let's just skip this for now - _rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); - #endif - // if(_rng != nullptr) + public: + nd4j::scalar::IntOps s; + nd4j::pairwise::IntOps p; + nd4j::broadcast::IntOps b; + + BroadcastIntOpsTuple() = default; + ~BroadcastIntOpsTuple() = default; + + BroadcastIntOpsTuple(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastIntOpsTuple custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps 
broadcast); + }; } -ProviderRNG& ProviderRNG::getInstance() { - - static ProviderRNG instance; - return instance; -} -random::RandomBuffer* ProviderRNG::getRNG() const { - - return _rng; -} - -std::mutex ProviderRNG::_mutex; - -} +#endif //DEV_TESTS_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp new file mode 100644 index 000000000..ff72ff4b9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/bits_hamming_distance.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_bits_hamming_distance) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length"); + REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type"); + + helpers::hamming(block.launchContext(), *x, *y, *output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(bits_hamming_distance) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64)); + } + + DECLARE_TYPES(bits_hamming_distance) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index 2aac5c6f9..89d380d02 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift); - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of 
data type") + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index 0bdb9503d..f18314910 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift); - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type") + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 4068351a2..36b0defd0 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type") - - helpers::rshift_bits(block.launchContext(), *input, *output, shift); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index f79da1024..ab4ed9880 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else 
if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type") - - helpers::shift_bits(block.launchContext(), *input, *output, shift); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/ismax.cpp b/libnd4j/include/ops/declarable/generic/convo/ismax.cpp index ad5a485e1..13de73e81 100644 --- a/libnd4j/include/ops/declarable/generic/convo/ismax.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/ismax.cpp @@ -45,7 +45,7 @@ DECLARE_SYN(IsMax, ismax); DECLARE_TYPES(ismax) { getOpDescriptor() ->setAllowedInputTypes(0, DataType::ANY) - ->setAllowedOutputTypes(0, DataType::BOOL); + ->setAllowedOutputTypes(0, DataType::ANY); } diff --git a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp index 2ae69e296..21906f4eb 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/cbow.cpp @@ -84,7 +84,8 @@ namespace nd4j { ->setAllowedInputTypes(11, nd4j::DataType::INT64) ->setAllowedInputTypes(12, nd4j::DataType::INT32) ->setAllowedInputTypes(13, nd4j::DataType::INT32) - ->setAllowedInputTypes(14, {ALL_FLOATS}); + ->setAllowedInputTypes(14, {ALL_FLOATS}) + ->setAllowedOutputTypes(nd4j::DataType::ANY); } } } diff --git a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp index 78c6e3818..a97e1a79e 100644 --- a/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp +++ b/libnd4j/include/ops/declarable/generic/nlp/skipgram.cpp @@ -79,7 +79,7 @@ namespace nd4j { ->setAllowedInputTypes(9, {ALL_FLOATS}) ->setAllowedInputTypes(10, nd4j::DataType::INT64) ->setAllowedInputTypes(11, {ALL_FLOATS}) - ->setAllowedOutputTypes(nd4j::DataType::INT8); + ->setAllowedOutputTypes(nd4j::DataType::ANY); } /* diff --git a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp index 08dba09f2..d96f97c10 100644 --- a/libnd4j/include/ops/declarable/generic/nn/softmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/softmax.cpp @@ -70,7 +70,7 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) { DECLARE_TYPES(softmax_bp) { getOpDescriptor() - ->setAllowedInputTypes(DataType::ANY) + ->setAllowedInputTypes({ALL_FLOATS}) ->setAllowedOutputTypes({ALL_FLOATS}); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp index 8ae503ba7..fa95997be 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/broadcast_dynamic_shape.cpp @@ -15,54 +15,79 @@ ******************************************************************************/ // -// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include #if NOT_EXCLUDED(OP_broadcast_dynamic_shape) -//#include #include -#include namespace nd4j { - namespace ops { - DECLARE_TYPES(broadcast_dynamic_shape) { - getOpDescriptor() - ->setAllowedOutputTypes({ALL_INTS}) - ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); - } +namespace ops { - CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) { - auto x_shape = INPUT_VARIABLE(0); - auto y_shape = INPUT_VARIABLE(1); - - 
REQUIRE_TRUE(shape::isVector(x_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The first argument should be a vector"); - REQUIRE_TRUE(shape::isVector(y_shape->shapeInfo(), 1), 0, "broadcast_dynamic_shape: The second argument should be a vector"); +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(broadcast_dynamic_shape, 2, 1, false, 0, 0) { - auto output = OUTPUT_VARIABLE(0); - - return helpers::bdsFunctor(block.launchContext(), x_shape, y_shape, output); - } + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); - DECLARE_SHAPE_FN(broadcast_dynamic_shape) { - auto shapeList = SHAPELIST(); - - auto theFirst = inputShape->at(0); - auto theSecond = inputShape->at(1); + auto z = OUTPUT_VARIABLE(0); - auto theFirstLen = shape::sizeAt(theFirst, -1); - auto theSecondLen = shape::sizeAt(theSecond, -1); + REQUIRE_TRUE(x->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the first input array must have rank = 1, but got %i instead!", x->rankOf()); + REQUIRE_TRUE(y->rankOf() == 1, 0, "BROADCAST_DYNAMIC_SHAPE OP: the second input array must have rank = 1, but got %i instead!", y->rankOf()); + REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "BROADCAST_DYNAMIC_SHAPE OP: both input arrays must have the same integer type !"); - auto shapeLength = nd4j::math::nd4j_max(theFirstLen, theSecondLen); + // contract shapeInfos, neglect and don't fill strides, ews, order + // shapes are of interest only + std::vector xShapeInfo(shape::shapeInfoLength(x->lengthOf())); + std::vector yShapeInfo(shape::shapeInfoLength(y->lengthOf())); - auto newshape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shapeLength, ArrayOptions::dataType(theFirst)); - shapeList->push_back(newshape); - return shapeList; - } + // fill rank and data type + xShapeInfo[0] = x->lengthOf(); + yShapeInfo[0] = y->lengthOf(); + ArrayOptions::setDataType(xShapeInfo.data(), nd4j::DataType::INT64); // fill with some data type, it doesn't matter what type exactly to choose + ArrayOptions::setDataType(yShapeInfo.data(), nd4j::DataType::INT64); - } + for (Nd4jLong i = 0; i < x->lengthOf(); ++i) + xShapeInfo[i + 1] = x->e(i); + + for (Nd4jLong i = 0; i < y->lengthOf(); ++i) + yShapeInfo[i + 1] = y->e(i); + + Nd4jLong* poinerOnOutShapeInfo = nullptr; + + const bool isBroadcastPossible = ShapeUtils::evalBroadcastShapeInfo(xShapeInfo.data(), yShapeInfo.data(), true, poinerOnOutShapeInfo, block.launchContext()->getWorkspace()); + + REQUIRE_TRUE(isBroadcastPossible, 0, "BROADCAST_DYNAMIC_SHAPE OP: the shapes of two input arrays %s and %s are not suitable for broadcast operation !", ShapeUtils::shapeAsString(xShapeInfo.data()).c_str(), ShapeUtils::shapeAsString(yShapeInfo.data()).c_str()); + + for (Nd4jLong i = 0; i < z->lengthOf(); ++i) + z->p(i, poinerOnOutShapeInfo[i + 1]); + + return Status::OK(); +} + +DECLARE_TYPES(broadcast_dynamic_shape) { + getOpDescriptor() + ->setAllowedOutputTypes({ALL_INTS}) + ->setAllowedInputTypes({ALL_INTS}); +} + + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(broadcast_dynamic_shape) { + + const int xRank = INPUT_VARIABLE(0)->lengthOf(); + const int yRank = INPUT_VARIABLE(1)->lengthOf(); + + const int maxRank = xRank > yRank ? 
xRank : yRank; + + auto outputShapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxRank, ArrayOptions::dataType(inputShape->at(0))); + + return SHAPELIST(outputShapeInfo); +} + +} } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp index 0bcbd2439..0d9465b02 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp @@ -35,6 +35,9 @@ OP_IMPL(scatter_add, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const int inRank = input->rankOf(); @@ -68,10 +71,8 @@ OP_IMPL(scatter_add, 3, 1, true) { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); - - helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); + if (!indices->isEmpty()) + helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp index a711916a1..dccc34e59 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp @@ -16,7 +16,7 @@ // // @author Created by raver119 on 24.11.17. -// @author Yurii Shyrma (iuriish@yahoo.com) +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -28,21 +28,24 @@ namespace nd4j { namespace ops { OP_IMPL(scatter_div, 3, 1, true) { - + auto input = INPUT_VARIABLE(0); auto indices = INPUT_VARIABLE(1); auto updates = INPUT_VARIABLE(2); auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? 
false : B_ARG(0); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -50,28 +53,27 @@ namespace nd4j { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); + if (!indices->isEmpty()) + // ScatterHelper::template scatterApply>(output, indices, updates); + helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); return Status::OK(); } DECLARE_SYN(ScatterDiv, scatter_div); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp index a9f0ab889..5d37a71d0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp @@ -35,14 +35,17 @@ OP_IMPL(scatter_max, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? 
false : B_ARG(0); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MAX OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -50,28 +53,26 @@ OP_IMPL(scatter_max, 3, 1, true) { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MAX OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); + if (!indices->isEmpty()) + helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); - helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); - return Status::OK(); } DECLARE_SYN(ScatterMax, scatter_max); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp index cce22b6fb..1bed296f9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp @@ -35,14 +35,17 @@ OP_IMPL(scatter_min, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? 
false : B_ARG(0); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MIN OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -50,27 +53,25 @@ OP_IMPL(scatter_min, 3, 1, true) { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MIN OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); - - helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); + if (!indices->isEmpty()) + helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp index 02eebb50c..46b9f7008 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp @@ -39,9 +39,13 @@ namespace nd4j { const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + + if (!block.isInplace()) + output->assign(input); + + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MUL OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -49,27 +53,25 @@ namespace nd4j { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", 
ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MUL OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); - - helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); + if (!indices->isEmpty()) + helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp index de2bf4fa4..cf3745236 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp @@ -34,14 +34,17 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_SUB OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -49,29 +52,27 @@ namespace nd4j { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_SUB OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); - - // ScatterHelper::template scatterApply>(output, indices, 
updates); - helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); + if (!indices->isEmpty()) + // ScatterHelper::template scatterApply>(output, indices, updates); + helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp index bc13581bf..55076e51e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp @@ -33,14 +33,17 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); + if (!block.isInplace()) + output->assign(input); + const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); - + REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !"); - + if(inRank == 1) { REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_UPD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str()); } @@ -48,28 +51,26 @@ namespace nd4j { std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = {indices->lengthOf()}; + std::vector expectedUpdShape = {indices->lengthOf()}; expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } else { - + REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_UPD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank); - + std::vector updShape = updates->getShapeAsVector(); std::vector inShape = input->getShapeAsVector(); - std::vector expectedUpdShape = indices->getShapeAsVector(); + std::vector expectedUpdShape = indices->getShapeAsVector(); expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end()); - + REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!block.isInplace()) - output->assign(input); - - // ScatterHelper::template scatterApply>(output, indices, updates); - helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); + if (!indices->isEmpty()) + // ScatterHelper::template scatterApply>(output, indices, updates); + helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp index ea2e3330a..bd16cdd79 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp @@ -87,7 +87,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(0, nd4j::DataType::ANY) - ->setAllowedOutputTypes(1, 
{ALL_INTS}); + ->setAllowedOutputTypes(1, {ALL_INDICES}); } } } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 5249758bf..3c165f64f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -167,9 +167,7 @@ DECLARE_SHAPE_FN(concat) { } for(int i = 1; i < numOfArrs; ++i) - if (!shape::isEmpty(arrShapes[i])) { - outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; - } + outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); diff --git a/libnd4j/include/ops/declarable/generic/transforms/eye.cpp b/libnd4j/include/ops/declarable/generic/transforms/eye.cpp index ef1d8e1e5..c5f3dbff6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/eye.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/eye.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { - CUSTOM_OP_IMPL(eye, -2, 1, false, 0, -2) { + CUSTOM_OP_IMPL(eye, -2, 1, false, -2, -2) { helpers::eye(block.launchContext(), *OUTPUT_VARIABLE(0)); @@ -44,8 +44,7 @@ namespace ops { std::vector params; - // FIX ME: original has a dtype param - so should be used here instead. e.g. (DataType) INT_ARG(0); - nd4j::DataType dtype = nd4j::DataType::FLOAT32; + nd4j::DataType dtype = block.getTArguments()->empty() ? nd4j::DataType::FLOAT32 : nd4j::DataTypeUtils::fromInt(T_ARG(0)); if(block.width() == 0) { params = *block.getIArguments(); @@ -54,27 +53,27 @@ namespace ops { for (int i = 0; i < block.width(); i++) { auto input = INPUT_VARIABLE(i); REQUIRE_TRUE(input->rankOf() == 1, 0, "Inputs to eye should be 1D"); - for (int e = 0; e < input->lengthOf(); e++) { + + for (int e = 0; e < input->lengthOf(); e++) params.emplace_back(input->e(e)); - } } } - REQUIRE_TRUE(params.size() > 0, 0, "Size not provided for eye op."); + REQUIRE_TRUE(params.size() > 0, 0, "Size is not provided for eye op."); const bool ordered = (params[0] == -99 || params[0] == -102); // -99 :'c', -102 : 'f' if (!ordered) params.insert(params.begin(), -99); - REQUIRE_TRUE(params.size() > 1, 0, "Size not provided for eye op."); + REQUIRE_TRUE(params.size() > 1, 0, "Size is not provided for eye op."); Nd4jLong* outShapeInfo(nullptr); const int size = params.size(); switch(size) { - + case 2: ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong); outShapeInfo[0] = 2; @@ -99,7 +98,7 @@ namespace ops { outShapeInfo[i] = params[i+2]; break; } - + shape::updateStrides(outShapeInfo, static_cast(-params[0])); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, dtype)); RELEASE(outShapeInfo, block.getWorkspace()); diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp index b3063c75d..529446e12 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram_fixed_width.cpp @@ -49,7 +49,7 @@ CUSTOM_OP_IMPL(histogram_fixed_width, 2, 1, false, 0, 0) { DECLARE_TYPES(histogram_fixed_width) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); + ->setAllowedOutputTypes({ALL_INDICES}); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp 
b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp index 4e612565e..8ab5fa32f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp @@ -32,12 +32,19 @@ namespace ops { auto input = INPUT_VARIABLE(0); auto gain = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); - + std::vector axis = *block.getIArguments(); + const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + NDArray* bias = nullptr; - if (block.width() > 2) + if (block.width() > 2) { bias = INPUT_VARIABLE(2); + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + } std::vector longAxis = ArrayUtils::toLongVector(axis); @@ -48,9 +55,12 @@ namespace ops { std::vector bargs = {}; standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); - output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output); - if(bias != nullptr) - output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output); + // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, output); + output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain); + if(bias != nullptr) { + // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output); + output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias); + } return Status::OK(); } @@ -71,12 +81,20 @@ namespace ops { auto dLdg = OUTPUT_VARIABLE(1); auto dLdb = block.width() == 4 ? OUTPUT_VARIABLE(2) : nullptr; + const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true; // INT_ARG(9): 0-NCHW, 1-NHWC + const int dimC = isNCHW ? 
1 : input->rankOf() - 1; + + REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str()); + std::vector axis = *block.getIArguments(); std::vector longAxis = ArrayUtils::toLongVector(axis); - if(bias != nullptr) - eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true); + if(bias != nullptr) { + REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str()); + // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true); + eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); + } NDArray standardized(input->shapeInfo(), false, block.launchContext()); @@ -88,10 +106,11 @@ namespace ops { standardizeOp.execute(inputs, outputs, targs, longAxis, bargs); standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr); - standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, {0}, true); + standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC})); nd4j::ops::standardize_bp standardizeBp; - eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx); + // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx); + eps->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain, dLdx); auto dLdx_tmp = dLdx->dup(); std::vector standardizeBpArgs = {input, dLdx_tmp}; diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index 900d42816..a6362a73f 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -45,7 +45,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_shift_bits) - DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); #endif /** @@ -56,7 +56,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_rshift_bits) - DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); #endif /** @@ -67,7 +67,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_cyclic_shift_bits) - DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); #endif /** @@ -78,7 +78,18 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); + #endif + + /** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_bits_hamming_distance) + DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0); #endif } } diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 75715f78e..b6fd57112 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -111,18 +111,21 @@ namespace nd4j { /** * creates identity 2D matrix or batch of identical 2D identity matrices - * + * * Input array: * provide some array - in any case operation simply neglects it - * + * + * Input float 
argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * * Input integer arguments: * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order * IArgs[1] - the number of rows in output inner-most 2D identity matrix * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape + * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape */ #if NOT_EXCLUDED(OP_eye) - DECLARE_CUSTOM_OP(eye, -2, 1, false, 0, 2); + DECLARE_CUSTOM_OP(eye, -2, 1, false, -2, 2); #endif #if NOT_EXCLUDED(OP_gather_nd) @@ -143,10 +146,10 @@ namespace nd4j { /** * clip a list of given tensors with given average norm when needed - * + * * Input: * a list of tensors (at least one) - * + * * Input floating point argument: * clip_norm - a value that used as threshold value and norm to be used * @@ -182,12 +185,12 @@ namespace nd4j { /** * returns histogram (as 1D array) with fixed bins width - * + * * Input arrays: - * - input array with elements to be binned into output histogram + * - input array with elements to be binned into output histogram * - range array with first element being bottom limit and second element being top limit of histogram, please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * + * * Input integer arguments: * nbins (optional) - number of histogram bins, default value is 100 */ diff --git a/libnd4j/include/ops/declarable/helpers/cpu/bds.cpp b/libnd4j/include/ops/declarable/helpers/cpu/bds.cpp deleted file mode 100644 index fd888ee87..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/bds.cpp +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author GS -// - -#include -#include - - -namespace nd4j { -namespace ops { -namespace helpers { - - Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) { - - - if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case - // lenght are equals - if (x_shape->lengthOf() == y_shape->lengthOf()) { - auto greater = (x_shape->e(0) < y_shape->e(0) ? y_shape : x_shape); - output->assign(greater); - } - else { - auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape); - auto greater = (x_shape->lengthOf() == 1 ? 
y_shape : x_shape); - output->assign(greater); - auto lastG = greater->lengthOf() - 1; - auto lastL = lesser->lengthOf() - 1; - if (greater->e(lastG) < lesser->e(lastL)) - output->p(lastG, lesser->e(lastL)); - } - } - else { - //int e = 0, x = 0, y = 0; - Nd4jLong xLen = x_shape->lengthOf(); - Nd4jLong yLen = y_shape->lengthOf(); - Nd4jLong zLen = output->lengthOf(); - Nd4jLong borderLen = nd4j::math::nd4j_min(xLen, yLen); - for (Nd4jLong e = 0; e < zLen; e++) { - Nd4jLong val; - if (e < borderLen) { - val = nd4j::math::nd4j_max(x_shape->e(e), y_shape->e(e)); - } else if (e < xLen) { - val = nd4j::math::nd4j_max(x_shape->e(e), y_shape->e(yLen - 1)); - } else { - val = nd4j::math::nd4j_max(x_shape->e(xLen - 1), y_shape->e(e)); - } - - output->p(e, val); - } - } - return Status::OK(); - } - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp new file mode 100644 index 000000000..660bd9354 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/hamming.cpp @@ -0,0 +1,72 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static void _hamming(NDArray &x, NDArray &y, NDArray &z) { + auto xEws = x.ews(); + auto yEws = y.ews(); + + auto xBuffer = x.bufferAsT(); + auto yBuffer = y.bufferAsT(); + + Nd4jLong distance = 0; + + if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { + PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) + for (Nd4jLong e = 0; e < x.lengthOf(); e++) { + auto _x = static_cast(xBuffer[e]); + auto _y = static_cast(yBuffer[e]); + + distance += __builtin_popcountll(_x ^ _y); + } + + } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { + PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) + for (Nd4jLong e = 0; e < x.lengthOf(); e++) { + auto _x = static_cast(xBuffer[e * xEws]); + auto _y = static_cast(yBuffer[e * yEws]); + + distance += __builtin_popcountll(_x ^ _y); + } + } else { + PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) + for (Nd4jLong e = 0; e < x.lengthOf(); e++) { + auto _x = static_cast(x.e(e)); + auto _y = static_cast(y.e(e)); + + distance += __builtin_popcountll(_x ^ _y); + } + } + + z.p(0, distance); + } + + void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (x, y, output), INTEGER_TYPES, INDEXING_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp index a3f3e2f9e..349d0381a 100644 --- 
a/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogramFixedWidth.cpp @@ -28,13 +28,10 @@ namespace helpers { template void histogramFixedWidth_(const NDArray& input, const NDArray& range, NDArray& output) { - const int nbins = output.lengthOf(); + const int nbins = output.lengthOf(); - // firstly initialize output with zeros - if(output.ews() == 1) - memset(output.buffer(), 0, nbins * output.sizeOfT()); - else - output = 0; + // firstly initialize output with zeros + output.nullify(); const T leftEdge = range.e(0); const T rightEdge = range.e(1); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index 45024b5cb..d673e64bd 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -39,11 +39,28 @@ namespace helpers { template static void reluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x > (T)0.f ? y : T(0.f); + + T zero = (T) 0.f; + auto functor = LAMBDA_TT(x, y, zero){ + return x > zero ? y : zero; }; input->applyPairwiseLambda(epsilon, functor, output); + + /* + auto x = input->bufferAsT(); + auto y = epsilon->bufferAsT(); + auto z = output->bufferAsT(); + + int length = input->lengthOf(); + + T zero = (T) 0.f; + + PRAGMA_OMP_PARALLEL_FOR + for (int e = 0; e < length; e++) { + z[e] = x[e] > zero ? y[e] : zero; + } + */ } void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index c1d01930c..0b16ac989 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -54,6 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) std::vector dimsToExcludeUpd(sizeOfDims); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); + shape::printIntArray(dimsToExcludeUpd.data(),dimsToExcludeUpd.size()); + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! 
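            // Illustrative sketch (example shapes are assumed, not taken from this diff):
            // the loop below applies one sub-array of `updates` per element of `indices`,
            // and dimsToExcludeUpd ({0, ..., sizeOfDims-1}, built above) is what selects
            // those per-index sub-arrays (TADs). For pairwise::Add with
            //     input/output shape [4, 2], indices = [1, 3], updates shape [2, 2]
            // (updates shape == indices shape ++ input shape[1:], matching the op-level checks),
            // the effect is:
            //     output[1, :] += updates[0, :];
            //     output[3, :] += updates[1, :];
            // with every other row of output keeping the values copied from input.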
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu index 80e5d885b..f402944aa 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/activations.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/activations.cu @@ -24,6 +24,7 @@ #include #include #include +#include namespace nd4j { namespace ops { @@ -196,7 +197,7 @@ void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& /////////////////////////////////////////////////////////////////// template -__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { +__device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { // logic of this kernel is based on assumption gridDim = 1 @@ -210,7 +211,7 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo if (threadIdx.x == 0) { extern __shared__ char shared[]; shmem = reinterpret_cast(shared); - len = shape::length(xzShapeInfo); + len = shape::length(xShapeInfo); numOfIters = (len + blockDim.x - 1) / blockDim.x; // ceil (len / blockDim.x) } __syncthreads(); @@ -222,8 +223,8 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); - shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len); + shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp } else shmem[threadIdx.x] = -DataTypeUtils::max(); // FIXME: what if T is unsigned ?? @@ -248,9 +249,10 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; if(elemIdx < len) { - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); - z[offset] = nd4j::math::nd4j_exp(x[offset] - max); - shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp + const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len); + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len); + z[zOffset] = nd4j::math::nd4j_exp(x[xOffset] - max); + shmem[threadIdx.x] = (threadIdx.x != 0) ? 
z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp } else shmem[threadIdx.x] = 0; @@ -270,43 +272,87 @@ __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo for (int i = 0; i < numOfIters; ++i) { const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; if(elemIdx >= len) continue; - const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); - z[offset] /= shmem[0]; + const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len); + z[zOffset] /= shmem[0]; } } +template +__global__ void softMaxForVectorCudaGlobal(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { + + softMaxForVectorCuda(vx, xShapeInfo, vz, zShapeInfo); +} + /////////////////////////////////////////////////////////////////// template -linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { +linkage void softMaxForVectorCudaLauncher(const cudaStream_t* stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - softMaxForVectorCuda<<<1, MAX_NUM_THREADS, MAX_NUM_THREADS * sizeof(T) + 512, *stream>>>(vx, xzShapeInfo, vz); + softMaxForVectorCudaGlobal<<<1, MAX_NUM_THREADS / 4 , (MAX_NUM_THREADS / 4) * sizeof(T) + 512, *stream>>>(vx, xShapeInfo, vz, zShapeInfo); } +/////////////////////////////////////////////////////////////////// +template +__global__ static void softMaxCuda(const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + const auto* xTad = x + xOffsets[blockIdx.x]; + auto* zTad = z + zOffsets[blockIdx.x]; + + softMaxForVectorCuda(xTad, xTadShapeInfo, zTad, zTadShapeInfo); +} + +/////////////////////////////////////////////////////////////////// +template +static void softMaxCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, + void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets) { + + softMaxCuda<<>>(vx, xTadShapeInfo, xOffsets, vz, zTadShapeInfo, zOffsets); +} + + ////////////////////////////////////////////////////////////////////////// void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& output, const int dimension) { if(!input.isActualOnDeviceSide()) input.syncToDevice(); const int rank = input.rankOf(); + PointersManager manager(context, "helpers::softmax"); + if(input.isVector()) { if(rank == 1 || input.sizeAt(dimension) != 1) { - BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer()), FLOAT_TYPES); - input.tickReadDevice(); + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxForVectorCudaLauncher, (context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); } else output = 1.; } else { - auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); - (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily - auto sumAlongDim = 
output.reduceAlongDims(reduce::Sum, {dimension}, true); - output /= sumAlongDim; - input.tickReadDevice(); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {dimension}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), {dimension}); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = packZ.numberOfTads(); + const int sharedMem = input.sizeOfT() * threadsPerBlock + 512; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), softMaxCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), output.specialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets()), FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + // auto maxAlongDim = const_cast(input).reduceAlongDims(reduce::Max, {dimension}, true); + // (input - maxAlongDim).applyTransform(transform::Exp, &output); // output contains exponents temporarily + // auto sumAlongDim = output.reduceAlongDims(reduce::Sum, {dimension}, true); + // output /= sumAlongDim; + // input.tickReadDevice(); } - PointersManager manager(context, "helpers::softmax"); + manager.synchronize(); output.tickWriteDevice(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu b/libnd4j/include/ops/declarable/helpers/cuda/bds.cu deleted file mode 100644 index ef501eac0..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/bds.cu +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author GS -// - -#include -#include - - -namespace nd4j { -namespace ops { -namespace helpers { - - - template - static __global__ void bdsLoopKernel(void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) { - __shared__ T const* x; - __shared__ T const* y; - __shared__ T* z; - __shared__ bool speedWay; - //__shared__ int indexX, indexY; - __shared__ Nd4jLong xLen, yLen, outputLen; - if (threadIdx.x == 0) { - x = reinterpret_cast(inputX); - y = reinterpret_cast(inputY); - z = reinterpret_cast(output); - xLen = shape::length(inputXshape); - yLen = shape::length(inputYshape); - outputLen = shape::length(outputShape); - speedWay = true; - speedWay = speedWay && (shape::elementWiseStride(inputXshape) == 1); - speedWay = speedWay && (shape::elementWiseStride(inputYshape) == 1); - speedWay = speedWay && (shape::elementWiseStride(outputShape) == 1); - - } - __syncthreads(); - - auto tid = threadIdx.x + blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - for (int e = tid; e < outputLen; e += step) { - T val; - if (speedWay) { - if (e < nd4j::math::nd4j_min(yLen, xLen)) { - val = nd4j::math::nd4j_max(x[e], y[e]); - } else if (e < xLen) { - val = nd4j::math::nd4j_max(x[e], y[yLen - 1]); - } else { - val = nd4j::math::nd4j_max(x[xLen - 1], y[e]); - } - z[e] = val; - } - else { - auto xIndex = e < xLen?shape::getIndexOffset(e, inputXshape, xLen):shape::getIndexOffset(xLen, inputXshape, xLen); - auto yIndex = e < yLen?shape::getIndexOffset(e, inputYshape, yLen):shape::getIndexOffset(yLen - 1, inputYshape, yLen); - auto zIndex = shape::getIndexOffset(e, outputShape, outputLen); - z[zIndex] = nd4j::math::nd4j_max(x[xIndex], y[yIndex]); - } - } - } - - template - static void bdsLoopH(cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) { - bdsLoopKernel<<<1, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape); - - } - - Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) { - //int e = 0, x = 0, y = 0; - NDArray::prepareSpecialUse({output}, {x_shape, y_shape}); - if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case - x_shape->syncToHost(); y_shape->syncToHost(); - if (x_shape->lengthOf() == y_shape->lengthOf()) { - auto greater = (x_shape->e(0) < y_shape->e(0) ? y_shape : x_shape); - output->assign(greater); - } - else { - auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape); - auto greater = (x_shape->lengthOf() == 1 ? 
y_shape : x_shape); - output->assign(greater); - auto lastG = greater->lengthOf() - 1; - auto lastL = lesser->lengthOf() - 1; - if (greater->e(lastG) < lesser->e(lastL)) - output->p(lastG, lesser->e(lastL)); - output->syncToDevice(); - } - } - else { - //bdsLoopH(context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), y->getSpecialBuffer(), y->getSpecialShape(), output->specialBuffer(), output->specialShapeInfo()) - BUILD_SINGLE_SELECTOR(output->dataType(), bdsLoopH, (context->getCudaStream(), x_shape->getSpecialBuffer(), x_shape->getSpecialShapeInfo(), y_shape->getSpecialBuffer(), y_shape->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES); - } - NDArray::registerSpecialUse({output}, {x_shape, y_shape}); - return Status::OK(); - return Status::OK(); - } - -} -} -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu index 513911f97..12f14b20b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/confusion.cu @@ -30,10 +30,10 @@ namespace helpers { template __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong bufferLength) { - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < bufferLength; t += step) { - destination[t] = reinterpret_cast(source)[t]; + destination[t] = static_cast(reinterpret_cast(source)[t]); } } @@ -51,38 +51,24 @@ namespace helpers { } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < bufferLength; t += step) { - //auto tX = reinterpret_cast(inputList[t]); - //auto xShape = reinterpret_cast(inputShapeList[t]); auto label = labelsBuffer[t]; //->e(j); auto pred = predictionBuffer[t]; //->e(j); auto tZ = z + tadOffsets[label]; T val = (weightsBuffer == nullptr ? (T)1.0f : w[t]); - //for (int e = threadIdx.x; e < arrLen; e += blockDim.x) { - - tZ[shape::getIndexOffset(pred, tadShape, arrLen)] = val; //tX[shape::getIndexOffset(e, , arrLen)]; + auto idx = shape::getIndexOffset(pred, tadShape, arrLen); + tZ[idx] = val; } } - template + template void _confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { -// std::unique_ptr arrs(output->allTensorsAlongDimension({1})); -// -//#pragma omp parallel for if(labels->lengthOf() > Environment::getInstance()->elementwiseThreshold()) schedule(static) -// for (int j = 0; j < labels->lengthOf(); ++j){ -// auto label = labels->e(j); -// auto pred = predictions->e(j); -// T value = (weights == nullptr ? 
(T)1.0f : weights->e(j)); -// (*arrs->at(label)).p(pred, value); -// } - - int dimension = 1; - - auto pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimension); + auto stream = context->getCudaStream(); + auto pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), 1); PointersManager manager(context, "helpers::confusion"); @@ -90,26 +76,26 @@ namespace helpers { Nd4jLong* predictionLongBuffer = predictions->dataType() == nd4j::DataType::INT64?(Nd4jLong*)predictions->specialBuffer():nullptr; if (labelsLongBuffer == nullptr) { - cudaError_t err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); + auto err = cudaMalloc(&labelsLongBuffer, labels->lengthOf() * sizeof(Nd4jLong)); if (err != 0) throw nd4j::cuda_exception::build("Cannot allocate memory for labels long buffer", err); // copy with type conversion - copyBuffers<<<256, 512, 8192>>>(labelsLongBuffer, labels->getSpecialBuffer(), labels->lengthOf()); + copyBuffers<<<256, 512, 1024, *stream>>>(labelsLongBuffer, labels->getSpecialBuffer(), labels->lengthOf()); } if (predictionLongBuffer == nullptr) { - cudaError_t err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong)); + auto err = cudaMalloc(&predictionLongBuffer, predictions->lengthOf() * sizeof(Nd4jLong)); if (err != 0) throw nd4j::cuda_exception::build("Cannot allocate memory for predictions long buffer", err); // copy with type conversion - copyBuffers<<<256, 512, 8192>>>(predictionLongBuffer, predictions->getSpecialBuffer(), predictions->lengthOf()); + copyBuffers<<<256, 512, 1024, *stream>>>(predictionLongBuffer, predictions->getSpecialBuffer(), predictions->lengthOf()); } auto bufferLength = labels->lengthOf(); dim3 launchDims(32, 32, 1024); - auto stream = context->getCudaStream(); - confusionFunctorKernel<<>>(labelsLongBuffer, predictionLongBuffer, - bufferLength, weights != nullptr? weights->getSpecialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets()); + confusionFunctorKernel<<>>(labelsLongBuffer, predictionLongBuffer, bufferLength, weights != nullptr? 
weights->getSpecialBuffer():nullptr, output->specialBuffer(), pack.specialShapeInfo(), pack.specialOffsets()); + + manager.synchronize(); if (predictionLongBuffer != predictions->getSpecialBuffer()) { cudaError_t err = cudaFree(predictionLongBuffer); @@ -122,17 +108,15 @@ namespace helpers { if (err != 0) throw nd4j::cuda_exception::build("Cannot deallocate memory for labels long buffer", err); } - manager.synchronize(); } void confusionFunctor(nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output) { - auto xType = output->dataType(); // weights can be null - - BUILD_SINGLE_SELECTOR(xType, _confusionFunctor, (context, labels, predictions, weights, output), NUMERIC_TYPES); + auto xType = predictions->dataType(); + auto zType = output->dataType(); // weights can be null + NDArray::prepareSpecialUse({output}, {labels, predictions, weights}); + BUILD_DOUBLE_SELECTOR(xType, zType, _confusionFunctor, (context, labels, predictions, weights, output), INDEXING_TYPES, NUMERIC_TYPES); + NDArray::registerSpecialUse({output}, {labels, predictions, weights}); } - - BUILD_SINGLE_TEMPLATE(template void _confusionFunctor, (nd4j::LaunchContext * context, NDArray* labels, NDArray* predictions, NDArray* weights, NDArray* output);, NUMERIC_TYPES); - } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 857ebed38..92e5b38b4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -50,8 +50,8 @@ namespace nd4j { auto zShapeInfo = zShapeInfos[o]; auto zLength = shape::length(zShapeInfo); - // iLimit should be - auto iLimit = iLength <= blockIdx.x ? blockIdx.x : (iLength + (blockIdx.x - (iLength % blockIdx.x))); + // iLimit should be multiple of blockDim.x + auto iLimit = iLength <= blockDim.x ? 
blockDim.x : (iLength + (blockDim.x - (iLength % blockDim.x))); int cnt = 0; for (Nd4jLong e = threadIdx.x; e < iLimit; e += blockDim.x) { @@ -75,8 +75,9 @@ namespace nd4j { // doing actual update if (e < iLength) - if (trueIndices[threadIdx.x] >= 0) + if (trueIndices[threadIdx.x] >= 0) { z[trueIndices[threadIdx.x]] = x[shape::getIndexOffset(e, xShapeInfo, xLength)]; + } __syncthreads(); } @@ -148,13 +149,12 @@ namespace nd4j { auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); - dynamicPartitionTadKernel<<<256, 512, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); + dynamicPartitionTadKernel<<<256, 256, 1024, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), shape::length(packX.primaryShapeInfo()), indices->getSpecialBuffer(), indices->getSpecialShapeInfo(), indices->lengthOf(), dOutBuffers, dOutTadShapes, dOutTadOffsets, outSize); } else { auto numThreads = 256; auto shmemSize = numThreads * sizeof(Y) * 2 + 1024; - std::vector outBuffers; std::vector outShapes; @@ -203,6 +203,9 @@ namespace nd4j { auto indices = reinterpret_cast(vindices[e]); auto iShapeInfo = iShapeInfos[e]; + if (shape::isEmpty(iShapeInfo)) + continue; + auto iLength = shape::length(iShapeInfo); auto zLength = shape::length(zTadShapeInfo); @@ -310,8 +313,9 @@ namespace nd4j { NDArray::registerSpecialUse({}, {indices, input}); - for (auto v:outputList) + for (auto v:outputList) { v->tickWriteDevice(); + } } template diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index 709f0ed2c..6587b4ca7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -106,6 +106,7 @@ namespace nd4j { const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, xCoordStart, xRank); z[zOffset] = x[xOffset]; + printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); } } @@ -124,7 +125,7 @@ namespace nd4j { const int maxRank = nd4j::math::nd4j_max(indices.rankOf(), nd4j::math::nd4j_max(input.rankOf(), output.rankOf())); - const int threadsPerBlock = MAX_NUM_THREADS; + const int threadsPerBlock = 256; const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int sharedMem = 8 * threadsPerBlock * maxRank + 128; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu new file mode 100644 index 000000000..3bc30e373 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/hamming.cu @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static _CUDA_G void _hammingKernel(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) { + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong *shared; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + shared = reinterpret_cast(shmem); + } + __syncthreads(); + + // we want to nullify temporary memory before accumulating intermediate results + shared[threadIdx.x] = 0; + + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) { + auto _x = static_cast(x[shape::getIndexOffset(e, xShapeInfo, length)]); + auto _y = static_cast(y[shape::getIndexOffset(e, yShapeInfo, length)]); + + // we save intermediate result into shared memory + shared[threadIdx.x] += __popcll(_x ^ _y); + } + __syncthreads(); + + // now we accumulate values + auto numItems = nd4j::math::nd4j_min(blockDim.x, length); + auto floorPow2 = numItems; + if (floorPow2 & (floorPow2 - 1)) { + + while (floorPow2 & (floorPow2 - 1)) + floorPow2 &= floorPow2 - 1; + + if (threadIdx.x >= floorPow2) + shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x]; + + __syncthreads(); + } + __syncthreads(); + + for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) { + if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems) + shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads]; + + __syncthreads(); + } + __syncthreads(); + + // FIXME: do we really want atomicAdd on global memory here + // and store them to output + if (threadIdx.x == 0 && shared[0] > 0) + nd4j::math::atomics::nd4j_atomicAdd(&z[0], static_cast(shared[threadIdx.x])); + } + + template + static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) { + _hammingKernel<<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf()); + } + + void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) { + NDArray::prepareSpecialUse({&output}, {&x, &y}); + BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&x, &y}); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu index 2c46210cf..317f1d857 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogramFixedWidth.cu @@ -20,110 +20,181 @@ #include #include +#include -namespace nd4j { -namespace ops { +namespace nd4j { +namespace ops { namespace helpers { - 
template - __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape, bufferLength)]; - } +/////////////////////////////////////////////////////////////////// +template +__global__ static void histogramFixedWidthCuda( const void* vx, const Nd4jLong* xShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, + const X leftEdge, const X rightEdge) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong xLen, zLen, totalThreads, nbins; + __shared__ X binWidth, secondEdge, lastButOneEdge; + + if (threadIdx.x == 0) { + + xLen = shape::length(xShapeInfo); + nbins = shape::length(zShapeInfo); // nbins = zLen + totalThreads = gridDim.x * blockDim.x; + + binWidth = (rightEdge - leftEdge ) / nbins; + secondEdge = leftEdge + binWidth; + lastButOneEdge = rightEdge - binWidth; } - template - __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - for (int t = tid; t < bufferLength; t += step) { - reinterpret_cast(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t]; - } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < xLen; i += totalThreads) { + + const X value = x[shape::getIndexOffset(i, xShapeInfo, xLen)]; + + Nd4jLong zIndex; + + if(value < secondEdge) + zIndex = 0; + else if(value >= lastButOneEdge) + zIndex = nbins - 1; + else + zIndex = static_cast((value - leftEdge) / binWidth); + + nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getIndexOffset(zIndex, zShapeInfo, nbins)], 1); } +} - template - static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) { +/////////////////////////////////////////////////////////////////// +template +__host__ static void histogramFixedWidthCudaLauncher(const cudaStream_t *stream, const NDArray& input, const NDArray& range, NDArray& output) { - __shared__ T const* x; - __shared__ Nd4jLong* z; // output buffer + const X leftEdge = range.e(0); + const X rightEdge = range.e(1); - if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); - } - __syncthreads(); - auto tid = blockIdx.x * gridDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; + histogramFixedWidthCuda<<<256, 256, 1024, *stream>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftEdge, rightEdge); +} - for(auto i = tid; i < inputLength; i += step) { +//////////////////////////////////////////////////////////////////////// +void histogramFixedWidth(nd4j::LaunchContext* context, const NDArray& input, const NDArray& range, NDArray& output) { - const T value = x[shape::getIndexOffset(i, inputShape, inputLength)]; - Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); + // firstly initialize output with zeros + output.nullify(); - if(value < secondEdge) - currInd = 0; - else if(value >= lastButOneEdge) - currInd = 
outputLength - 1; - nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL); - } - } + PointersManager manager(context, "histogramFixedWidth"); + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogramFixedWidthCudaLauncher, (context->getCudaStream(), input, range, output), LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); +} - template - void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { - const int nbins = output.lengthOf(); - auto stream = context->getCudaStream(); - // firstly initialize output with zeros - //if(output.ews() == 1) - // memset(output.buffer(), 0, nbins * output.sizeOfT()); - //else - output.assign(0); - if (!input.isActualOnDeviceSide()) - input.syncToDevice(); +// template +// __global__ static void copyBuffers(Nd4jLong* destination, void const* source, Nd4jLong* sourceShape, Nd4jLong bufferLength) { +// const auto tid = blockIdx.x * gridDim.x + threadIdx.x; +// const auto step = gridDim.x * blockDim.x; +// for (int t = tid; t < bufferLength; t += step) { +// destination[t] = reinterpret_cast(source)[shape::getIndexOffset(t, sourceShape, bufferLength)]; +// } +// } - const double leftEdge = range.e(0); - const double rightEdge = range.e(1); +// template +// __global__ static void returnBuffers(void* destination, Nd4jLong const* source, Nd4jLong* destinationShape, Nd4jLong bufferLength) { +// const auto tid = blockIdx.x * gridDim.x + threadIdx.x; +// const auto step = gridDim.x * blockDim.x; +// for (int t = tid; t < bufferLength; t += step) { +// reinterpret_cast(destination)[shape::getIndexOffset(t, destinationShape, bufferLength)] = source[t]; +// } +// } - const double binWidth = (rightEdge - leftEdge ) / nbins; - const double secondEdge = leftEdge + binWidth; - double lastButOneEdge = rightEdge - binWidth; - Nd4jLong* outputBuffer; - cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong)); - if (err != 0) - throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err); - copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf()); - histogramFixedWidthKernel<<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge); - returnBuffers<<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf()); - //cudaSyncStream(*stream); - err = cudaFree(outputBuffer); - if (err != 0) - throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err); - output.tickWriteDevice(); -//#pragma omp parallel for schedule(guided) -// for(Nd4jLong i = 0; i < input.lengthOf(); ++i) { -// -// const T value = input.e(i); -// -// if(value < secondEdge) -//#pragma omp critical -// output.p(0, output.e(0) + 1); -// else if(value >= lastButOneEdge) -//#pragma omp critical -// output.p(nbins-1, output.e(nbins-1) + 1); -// else { -// Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); -//#pragma omp critical -// output.p(currInd, output.e(currInd) + 1); -// } -// } - } +// template +// static __global__ void histogramFixedWidthKernel(void* outputBuffer, Nd4jLong outputLength, void const* inputBuffer, Nd4jLong* inputShape, Nd4jLong inputLength, 
double const leftEdge, double binWidth, double secondEdge, double lastButOneEdge) { - void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); +// __shared__ T const* x; +// __shared__ Nd4jLong* z; // output buffer + +// if (threadIdx.x == 0) { +// z = reinterpret_cast(outputBuffer); +// x = reinterpret_cast(inputBuffer); +// } +// __syncthreads(); +// auto tid = blockIdx.x * gridDim.x + threadIdx.x; +// auto step = blockDim.x * gridDim.x; + +// for(auto i = tid; i < inputLength; i += step) { + +// const T value = x[shape::getIndexOffset(i, inputShape, inputLength)]; +// Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); + +// if(value < secondEdge) +// currInd = 0; +// else if(value >= lastButOneEdge) +// currInd = outputLength - 1; +// nd4j::math::atomics::nd4j_atomicAdd(&z[currInd], 1LL); +// } +// } + + +// template +// void histogramFixedWidth_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { +// const int nbins = output.lengthOf(); +// auto stream = context->getCudaStream(); +// // firstly initialize output with zeros +// //if(output.ews() == 1) +// // memset(output.buffer(), 0, nbins * output.sizeOfT()); +// //else +// output.assign(0); +// if (!input.isActualOnDeviceSide()) +// input.syncToDevice(); + +// const double leftEdge = range.e(0); +// const double rightEdge = range.e(1); + +// const double binWidth = (rightEdge - leftEdge ) / nbins; +// const double secondEdge = leftEdge + binWidth; +// double lastButOneEdge = rightEdge - binWidth; +// Nd4jLong* outputBuffer; +// cudaError_t err = cudaMalloc(&outputBuffer, output.lengthOf() * sizeof(Nd4jLong)); +// if (err != 0) +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot allocate memory for output", err); +// copyBuffers<<<256, 512, 8192, *stream>>>(outputBuffer, output.getSpecialBuffer(), output.getSpecialShapeInfo(), output.lengthOf()); +// histogramFixedWidthKernel<<<256, 512, 8192, *stream>>>(outputBuffer, output.lengthOf(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), input.lengthOf(), leftEdge, binWidth, secondEdge, lastButOneEdge); +// returnBuffers<<<256, 512, 8192, *stream>>>(output.specialBuffer(), outputBuffer, output.specialShapeInfo(), output.lengthOf()); +// //cudaSyncStream(*stream); +// err = cudaFree(outputBuffer); +// if (err != 0) +// throw cuda_exception::build("helpers::histogramFixedWidth: Cannot deallocate memory for output buffer", err); +// output.tickWriteDevice(); +// //#pragma omp parallel for schedule(guided) +// // for(Nd4jLong i = 0; i < input.lengthOf(); ++i) { +// // +// // const T value = input.e(i); +// // +// // if(value < secondEdge) +// //#pragma omp critical +// // output.p(0, output.e(0) + 1); +// // else if(value >= lastButOneEdge) +// //#pragma omp critical +// // output.p(nbins-1, output.e(nbins-1) + 1); +// // else { +// // Nd4jLong currInd = static_cast((value - leftEdge) / binWidth); +// //#pragma omp critical +// // output.p(currInd, output.e(currInd) + 1); +// // } +// // } +// } + +// void histogramFixedWidth(nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output) { +// BUILD_SINGLE_SELECTOR(input.dataType(), 
histogramFixedWidth_, (context, input, range, output), LIBND4J_TYPES); +// } +// BUILD_SINGLE_TEMPLATE(template void histogramFixedWidth_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& range, NDArray& output), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index fab9577d6..e86cd382a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -30,51 +30,9 @@ namespace nd4j { namespace ops { namespace helpers { - template - inline void __device__ indexSwap(T* arr, Nd4jLong idx1, Nd4jLong idx2) { - T tmp = arr[idx1]; - arr[idx1] = arr[idx2]; - arr[idx2] = tmp; - } -// template -// void reverseArray(nd4j::LaunchContext * context, void* inArr, Nd4jLong *inShapeBuffer, void *result, Nd4jLong *zShapeBuffer, int numOfElemsToReverse = 0); - - ///////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void reverseArrayInplaceKernel(void *input, Nd4jLong *inputShape, Nd4jLong numOfElemsToReverse) { - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - __shared__ Nd4jLong length; - __shared__ int linearStatus; - __shared__ T* inputArr; - if (threadIdx.x == 0) { - length = shape::length(inputShape); - linearStatus = shape::elementWiseStride(inputShape); - inputArr = reinterpret_cast(input); - } - __syncthreads(); - - for (Nd4jLong e = tid; e < numOfElemsToReverse / 2; e += step) { - if (linearStatus == 1) { - auto idx = numOfElemsToReverse - e - 1; - indexSwap(inputArr, e, idx); - } - else if (linearStatus > 1) { - auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus; - Nd4jLong idx2 = e * linearStatus; - indexSwap(inputArr, idx1, idx2); - } - else { - auto inOffset = shape::getIndexOffset(e, inputShape, length); - auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length); - indexSwap(inputArr, inOffset, outOffset); - } - } - } - template static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; __shared__ Nd4jLong length; __shared__ int linearStatus; @@ -93,51 +51,47 @@ namespace helpers { } __syncthreads(); - for (Nd4jLong e = tid; e < length; e += step) { - if (e < numOfElemsToReverse ) { - if (linearStatus == 1) { - auto idx = numOfElemsToReverse - e - 1; - outputArr[idx] = inputArr[e]; - } else if (linearStatus > 1) { - auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus; - Nd4jLong idx2 = e * linearStatus; - outputArr[idx1] = inputArr[idx2]; - } else { - auto inOffset = shape::getIndexOffset(e, inputShape, length); - auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length); - outputArr[outOffset] = inputArr[inOffset]; - } - } - else { - if (linearStatus == 1) { - outputArr[e] = inputArr[e]; - } else if (linearStatus > 1) { - auto idx1 = e * linearStatus; - Nd4jLong idx2 = e * linearStatus; - outputArr[idx1] = inputArr[idx2]; - } else { - auto inOffset = shape::getIndexOffset(e, inputShape, length); - auto outOffset = shape::getIndexOffset(e, outputShape, length); - outputArr[outOffset] = inputArr[inOffset]; - } - } + auto odd = numOfElemsToReverse % 2 != 0; + auto limit = numOfElemsToReverse 
/ 2; + + for (Nd4jLong e = tid; e < limit; e += step) { + // we're calculating offsets within input array + auto fOffset = shape::getIndexOffset(e, inputShape, length); + auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length); + + // now we're storing input values + auto v1 = inputArr[fOffset]; + auto v2 = inputArr[lOffset]; + + // now we're calculating offsets within output array + auto zfOffset = shape::getIndexOffset(e, outputShape, length); + auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length); + + // and saving values to output arrays + outputArr[zfOffset] = v2; + outputArr[zlOffset] = v1; + + //printf("TID: %i; E: %lld; z[%lld], z[%lld] = x[%lld], x[%lld];\n", tid, e, zfOffset, zlOffset, lOffset, fOffset); } - //printf("\n"); + // in case of odd array we'll have to move middle value + if (odd && tid == 0) { + auto xOffset = shape::getIndexOffset(limit, inputShape, length); + auto zOffset = shape::getIndexOffset(limit, outputShape, length); + + outputArr[zOffset] = inputArr[xOffset]; + //printf("TID: %i; E: %lld; z[%lld] = x[%lld];\n", tid, limit, zOffset, xOffset); + } } template - static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int numOfElemsToReverse) { + static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { auto stream = context->getCudaStream(); Nd4jLong numOfReverse = numOfElemsToReverse; if (numOfElemsToReverse == 0) numOfReverse = input->lengthOf(); - if (input == output) { - reverseArrayInplaceKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), numOfReverse); - } - else { - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); - } + + reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); } @@ -148,11 +102,8 @@ namespace helpers { seqLengths->syncToHost(); auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({output}, {input, seqLengths}); if(input->isVector() || shape::isLikeVector(input->getShapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) { int numOfElemsToReverse = seqLengths->e(0); -// printf("Length %d\n", numOfElemsToReverse); -// input->printBuffer("INPUT"); if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim)) output->assign(input); else @@ -168,7 +119,6 @@ namespace helpers { auto inSubArrsSet = input->allTensorsAlongDimension(dimensions); auto outSubArrsSet = output->allTensorsAlongDimension(dimensions); -// #pragma omp parallel for schedule(guided) if(inSubArrsSet->size() > Environment::getInstance()->elementwiseThreshold()) for(int i = 0; i < inSubArrsSet->size(); ++i) { int numOfElemsToReverse = seqLengths->e(i); @@ -189,11 +139,18 @@ namespace helpers { delete inSubArrsSet; delete outSubArrsSet; } - NDArray::registerSpecialUse({output}, {input, seqLengths}); + } void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { + NDArray::prepareSpecialUse({output}, {input, seqLengths}); + + // if op isn't inplace - copy original data into output array + if (output->getSpecialBuffer() != input->getSpecialBuffer()) + output->assign(input); + BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, 
(context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {input, seqLengths}); } ////////////////////////////////////////////////////////////////////////// @@ -221,7 +178,7 @@ namespace helpers { delete listIn; } -BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, int numOfElemsToReverse), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index ec0d304df..54d350f47 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -398,10 +398,15 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind const int xRank = indices.rankOf(); std::vector zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); - std::vector yTadDims(xRank); - std::iota(yTadDims.begin(), yTadDims.end(), xRank == 1 ? 0 : xRank); - auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), yTadDims); + int sizeOfUpdDims = xRank; + if(output.rankOf() == updates.rankOf() && indices.isVector()) + sizeOfUpdDims = 1; + + std::vector yTadDims(sizeOfUpdDims); + std::iota(yTadDims.begin(), yTadDims.end(), 0); + + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index 180af41e1..8830f37e7 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -40,19 +40,16 @@ namespace nd4j { static __global__ void segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - T *x; - __shared__ - T *z; + __shared__ T *val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T *x; + __shared__ T *z; __shared__ int threadsPerSegment, start, finish; + auto segment = blockIdx.x; if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; +// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; +// segment = blockIdx.x / threadsPerSegment; x = reinterpret_cast(input); z = reinterpret_cast(output); extern __shared__ unsigned char shmem[]; @@ -83,19 +80,14 @@ namespace nd4j { unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - T *x; - __shared__ - T *z; - __shared__ - I *y; //int threadsPerSegment, start, finish; + __shared__ T *val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T *x; + __shared__ T *z; + __shared__ I *y; //int threadsPerSegment, start, finish; 
+ auto segment = blockIdx.x; if (threadIdx.x == 0) { - segment = blockIdx.x; x = reinterpret_cast(input); z = reinterpret_cast(output); y = reinterpret_cast(indices); @@ -127,9 +119,10 @@ namespace nd4j { Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) { __shared__ T* val; - __shared__ Nd4jLong len, segment, zIndex, total; + __shared__ Nd4jLong len, zIndex, total; __shared__ T* z; __shared__ int start, finish; + __shared__ I segment; if (threadIdx.x == 0) { segment = indices[blockIdx.x]; // / threadsPerSegment; @@ -143,20 +136,22 @@ namespace nd4j { __syncthreads(); auto idx = blockIdx.x; - if (blockIdx.x <= total) { + if (idx <= total) { auto x = reinterpret_cast(inputBuf) + inputTadOffsets[idx]; if (blockIdx.x == start) { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - z[zIndex] = x[xIndex]; + nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + //z[zIndex] = x[xIndex]; } } else { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); + if (lengths[segment]) + nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); } } } @@ -168,6 +163,7 @@ namespace nd4j { //int numClasses = output->sizeAt(0); // if input is a vector: (as if in doc sample) //Nd4jLong idx = indices->e(0); + output->assign(-DataTypeUtils::infOrMax()); auto stream = context->getCudaStream(); indices->syncToHost(); Nd4jLong numOfClasses = indices->e(indices->lengthOf() - 1) + 1; @@ -201,7 +197,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -209,6 +207,8 @@ namespace nd4j { static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); + output->assign(DataTypeUtils::infOrMax()); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); @@ -240,7 +240,10 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // 
-------------------------------------------------------------------------------------------------------------- // @@ -370,8 +373,10 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -416,7 +421,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index 3f2168da4..19869f646 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -74,14 +74,14 @@ namespace helpers { template static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { __shared__ T* val; - __shared__ Nd4jLong xLen, zLen, segment, zIndex; + __shared__ Nd4jLong xLen, zLen, zIndex; __shared__ T* x; __shared__ T* z; __shared__ I* y; //int threadsPerSegment, start, finish; - + auto segment = blockIdx.x;// / if (threadIdx.x == 0) { // threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x;// / threadsPerSegment; +// threadsPerSegment; x = reinterpret_cast(input); z = reinterpret_cast(output); y = reinterpret_cast(indices); @@ -117,12 +117,12 @@ namespace helpers { template static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { __shared__ T* val; - __shared__ Nd4jLong len, segment, zIndex, total; + __shared__ Nd4jLong len, zIndex, total; __shared__ T* z; __shared__ int threadsPerSegment, start, finish; + auto segment = indices[blockIdx.x]; // / threadsPerSegment; if (threadIdx.x == 0) { - segment = indices[blockIdx.x]; // / threadsPerSegment; z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; len = shape::length(inputTads); start = starts[segment]; @@ -139,7 +139,7 @@ namespace helpers { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - z[zIndex] = T(x[xIndex]/lengths[segment]); + 
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); } } else { @@ -163,7 +163,7 @@ namespace helpers { classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); - + NDArray::prepareSpecialUse({output}, {input, indices}); dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); @@ -182,11 +182,14 @@ namespace helpers { Nd4jLong* outputTadOffsets = packZ.specialOffsets(); segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); } + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -194,6 +197,7 @@ namespace helpers { static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); @@ -225,8 +229,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -349,8 +355,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // // segmen mean bp main int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -399,7 +407,9 @@ namespace helpers { } // 
-------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index 0c67b41d5..e5ea2eb91 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -38,19 +38,16 @@ namespace helpers { static __global__ void segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, void *output, Nd4jLong *outputShape) { - __shared__ - T *val; - __shared__ - Nd4jLong xLen, zLen, segment, zIndex; - __shared__ - T *x; - __shared__ - T *z; + __shared__ T *val; + __shared__ Nd4jLong xLen, zLen, zIndex; + __shared__ T *x; + __shared__ T *z; __shared__ int threadsPerSegment, start, finish; + auto segment = blockIdx.x; if (threadIdx.x == 0) { - threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; - segment = blockIdx.x / threadsPerSegment; +// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; +// segment = blockIdx.x / threadsPerSegment; x = reinterpret_cast(input); z = reinterpret_cast(output); extern __shared__ unsigned char shmem[]; @@ -123,12 +120,12 @@ namespace helpers { template static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { __shared__ T* val; - __shared__ Nd4jLong len, segment, zIndex, total; + __shared__ Nd4jLong len, zIndex, total; __shared__ T* z; __shared__ int threadsPerSegment, start, finish; + auto segment = indices[blockIdx.x]; // / threadsPerSegment; if (threadIdx.x == 0) { - segment = indices[blockIdx.x]; // / threadsPerSegment; z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; len = shape::length(inputTads); start = starts[segment]; @@ -145,14 +142,15 @@ namespace helpers { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - z[zIndex] = x[xIndex]; + nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); } } else { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); +// if (lengths[indices[idx]]) + nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); } } } @@ -165,7 +163,7 @@ namespace helpers { Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); - + output->assign(DataTypeUtils::infOrMax()); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -192,7 +190,10 
@@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -205,6 +206,7 @@ namespace helpers { NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); + output->assign(DataTypeUtils::infOrMax()); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); @@ -233,8 +235,11 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } template @@ -364,8 +369,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // // segmen min int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } template @@ -409,7 +416,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 78f21916d..5709a63ea 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -146,14 +146,15 @@ namespace helpers { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - z[zIndex] = x[xIndex]; + 
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); } } else { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); + if (lengths[segment] > 0) + nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); } } } @@ -166,7 +167,7 @@ namespace helpers { Nd4jLong numClasses = indices->e(indices->lengthOf() - 1) + 1; NDArray classesRangesLens = NDArrayFactory::create('c', {numClasses}); NDArray classesRangesBegs = NDArrayFactory::create('c', {numClasses}); - + output->assign(1); classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); @@ -192,7 +193,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -231,8 +234,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -358,8 +363,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -367,6 +374,7 @@ namespace helpers { template static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); + NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT(), context);//->shapeInfo(), context); unsortedSegmentProdFunctor_(context, input, indices, numOfClasses, &tempRes); NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); @@ -404,7 +412,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , 
NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 4141cefba..229d41cc9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -146,8 +146,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // template @@ -270,7 +272,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 37dacee09..4b8976f4e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -121,12 +121,12 @@ namespace helpers { template static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { __shared__ T* val; - __shared__ Nd4jLong len, segment, zIndex, total; + __shared__ Nd4jLong len, zIndex, total; __shared__ T* z; - __shared__ int threadsPerSegment, start, finish; + __shared__ int start, finish; if (threadIdx.x == 0) { - segment = indices[blockIdx.x]; // / threadsPerSegment; + auto segment = indices[blockIdx.x]; // / threadsPerSegment; z = reinterpret_cast(outputBuf) + outputTadOffsets[segment]; len = shape::length(inputTads); start = starts[segment]; @@ -143,14 +143,14 @@ namespace helpers { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - z[zIndex] = x[xIndex]; + 
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); } } else { for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto xIndex = shape::getIndexOffset(e, inputTads, len); auto zIndex = shape::getIndexOffset(e, outputTads, len); - if (lengths[segment]) + if (lengths[indices[idx]]) nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); } } @@ -190,7 +190,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -229,8 +232,11 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); + output->nullify(); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } @@ -343,8 +349,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } template @@ -381,7 +389,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index c07c9adf8..0695119da 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -100,7 +100,7 @@ static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int thr BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES); 
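Both the segment helpers above and the SVD helpers below converge on the same host/device discipline: each public functor now brackets its dispatch with NDArray::prepareSpecialUse({outputs}, {inputs}) before launching CUDA work and NDArray::registerSpecialUse({outputs}, {inputs}) afterwards, and the additive reductions additionally call output->nullify() so the nd4j_atomicAdd accumulation starts from zero. A minimal sketch of that bracketing, assuming the libnd4j NDArray API shown in these hunks; launchSegmentSumKernel is a hypothetical stand-in for the kernel launch that BUILD_DOUBLE_SELECTOR expands to:

// Sketch only -- illustrates the prepare/register bracketing added in this patch.
// launchSegmentSumKernel is hypothetical; the real helper dispatches through
// BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), ...).
void segmentSumFunctorSketch(nd4j::LaunchContext* context,
                             NDArray* input, NDArray* indices, NDArray* output) {
    // sync host -> device where needed and mark the buffers as in use on the device
    NDArray::prepareSpecialUse({output}, {input, indices});

    // atomicAdd-based accumulation requires a zeroed output buffer
    output->nullify();

    launchSegmentSumKernel(context, input, indices, output);   // hypothetical launch

    // mark the output's device buffer as the freshest copy for later reads
    NDArray::registerSpecialUse({output}, {input, indices});
}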
////////////////////////////////////////////////////////////////////////// -static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) { +static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) { // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f' // we make this function to have deal with 2 valid cases only: @@ -113,59 +113,59 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND // U [m, m] or [m, n] if fullUV = false and m > n // VT [n, n] or [m, n] if fullUV = false and m < n - if(A.rankOf() != 2) + if(A->rankOf() != 2) throw std::runtime_error("svdQR: rank of A array is not equal 2 !"); - auto m = A.sizeAt(0); - auto n = A.sizeAt(1); + auto m = A->sizeAt(0); + auto n = A->sizeAt(1); const int minDim = m < n ? m : n; - const char orderA = A.ordering(); + const char orderA = A->ordering(); if(m < n) throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !"); - if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S)) + if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S)) throw std::runtime_error("svdQR: wrong shape of S array !"); if(calcUV) { - if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U)) + if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U)) throw std::runtime_error("svdQR: wrong shape of U array !"); - else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U)) + else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U)) throw std::runtime_error("svdQR: wrong shape of U array !"); - if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT)) + if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT)) throw std::runtime_error("svdQR: wrong shape of VT array !"); - else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT)) + else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT)) throw std::runtime_error("svdQR: wrong shape of VT array !"); } - NDArray* pA = const_cast(&A); - NDArray* pS = &S; - NDArray* pU = &U; - NDArray* pVT = &VT; + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pVT = VT; std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A.dup('f'); + pA = A->dup('f'); toDelete.push_back(pA); } - if(S.ews() != 1) { - pS = S.dup('f'); + if(S->ews() != 1) { + pS = S->dup('f'); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U.dup('f'); + pU = U->dup('f'); toDelete.push_back(pU); } if(pVT->ews() != 1 || pVT->ordering() == 'c') { - pVT = VT.dup('f'); + pVT = VT->dup('f'); toDelete.push_back(pVT); } } @@ -183,9 +183,9 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND // query working space of SVD int lwork = 0; - if(A.dataType() == DataType::DOUBLE) + if(A->dataType() == DataType::DOUBLE) status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork); - else if(A.dataType() == DataType::FLOAT32) + else if(A->dataType() == DataType::FLOAT32) status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork); else throw std::invalid_argument("svdQR: given data type is unsupported !"); 
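From here on, svdQR (and below it svdJcb and svdBatched) take NDArray pointers instead of references so the rank > 2 dispatch at the bottom of the file can pass nullptr for U and V when calcUV is false, and every cuSOLVER invocation guards the singular-vector buffers with a calcUV ? buffer : nullptr ternary instead of dereferencing them unconditionally. A hedged sketch of that guard; uvBuffer is a hypothetical helper, and the cusolverDnDgesvd call is shown only schematically in a comment:

// Sketch only: models the "calcUV ? special buffer : nullptr" guard added here.
// uvBuffer is a hypothetical helper, not part of the patch.
template <typename T, typename Array>
static T* uvBuffer(Array* arr, bool calcUV) {
    // when singular vectors are not requested (jobu/jobvt == 'N'), cuSOLVER
    // does not read U/VT, so a null pointer is safe to pass
    return calcUV ? reinterpret_cast<T*>(arr->getSpecialBuffer()) : nullptr;
}
// schematic use inside svdQR (double-precision branch):
//   status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, pAbuf, lda, pSbuf,
//                             uvBuffer<double>(pU, calcUV),  ldu,
//                             uvBuffer<double>(pVT, calcUV), ldvt,
//                             dWork, lwork, rWork, devInfo);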
@@ -195,7 +195,7 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND // allocate memory for dWork void* dWork = nullptr; - cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork); + cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); if(status2 != cudaSuccess) throw cuda_exception::build("svdQR: cuda failed !", status2); @@ -226,11 +226,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND NDArray::prepareSpecialUse({pS, pU, pVT}, {pA}); // choose appropriate cuda gemm api depending on data types - if(A.dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pVT->getSpecialBuffer()), ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + if(A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } - else if(A.dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pVT->getSpecialBuffer()), ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + else if(A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -242,11 +242,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND NDArray::registerSpecialUse({pS, pU, pVT}, {pA}); - S.assign(pS); + S->assign(pS); if(calcUV) { - U.assign(pU); - VT.assign(pVT); + U->assign(pU); + VT->assign(pVT); } for (int i = toDelete.size() - 1; i >= 0; --i) @@ -266,62 +266,62 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND } ////////////////////////////////////////////////////////////////////////// -static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) { +static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { // A [m, n] // S [n] // U [m, m] or [m, n] if fullUV = false and m > n // V [n, n] or [n, m] if fullUV = false and m < n - if(A.rankOf() != 2) + if(A->rankOf() != 2) throw std::runtime_error("svdJcb: rank of A array is not equal 2 !"); - auto m = A.sizeAt(0); - auto n = A.sizeAt(1); + auto m = A->sizeAt(0); + auto n = A->sizeAt(1); const int minDim = m < n ? 
m : n; - if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S)) + if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S)) throw std::runtime_error("svdJcb: wrong shape of S array !"); if(calcUV) { - if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U)) + if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U)) throw std::runtime_error("svdJcb: wrong shape of U array !"); - else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U)) + else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U)) throw std::runtime_error("svdJcb: wrong shape of U array !"); - if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V)) + if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V)) throw std::runtime_error("svdJcb: wrong shape of V array !"); - else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V)) + else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V)) throw std::runtime_error("svdJcb: wrong shape of V array !"); } - NDArray* pA = const_cast(&A); - NDArray* pS = &S; - NDArray* pU = &U; - NDArray* pV = &V; + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pV = V; std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A.dup('f'); + pA = A->dup('f'); toDelete.push_back(pA); } - if(S.ews() != 1) { - pS = S.dup('f'); + if(S->ews() != 1) { + pS = S->dup('f'); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U.dup('f'); + pU = U->dup('f'); toDelete.push_back(pU); } if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V.dup('f'); + pV = V->dup('f'); toDelete.push_back(pV); } } @@ -362,10 +362,10 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N // query working space of SVD int lwork = 0; - if(A.dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams); - else if(A.dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams); + if(A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); + else if(A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -374,7 +374,7 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N // allocate memory dWork void* dWork = nullptr; - auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork); + auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); if(status2 != cudaSuccess) throw cuda_exception::build("svdJcb: cuda failed !", status2); @@ -383,11 +383,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); // choose appropriate cuda gemm api depending on data types - if(A.dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + if(A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } - else if(A.dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + else if(A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -399,11 +399,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - S.assign(pS); + S->assign(pS); if(calcUV) { - U.assign(pU); - V.assign(pV); + U->assign(pU); + V->assign(pV); } for (int i = toDelete.size() - 1; i >= 0; --i) @@ -422,67 +422,67 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N } ////////////////////////////////////////////////////////////////////////// -static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) { +static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) { // A [..., m, n] // S [..., n] // U [..., m, m] or [..., m, n] if fullUV = false and m > n // V [..., n, n] or [..., n, m] if fullUV = false and m < n - auto m = A.sizeAt(-2); - auto n = A.sizeAt(-1); + auto m = A->sizeAt(-2); + auto n = A->sizeAt(-1); const int minDim = m < n ? 
m : n; - const Nd4jLong bS = A.lengthOf() / (m * n); + const Nd4jLong bS = A->lengthOf() / (m * n); if(m > 32 || n > 32) throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !"); - if(minDim != S.sizeAt(-1)) + if(minDim != S->sizeAt(-1)) throw std::runtime_error("svdBatched: wrong shape of S array !"); if(calcUV) { - if(U.sizeAt(-2) != m) + if(U->sizeAt(-2) != m) throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U.sizeAt(-1) != (fullUV ? m : minDim)) + if(U->sizeAt(-1) != (fullUV ? m : minDim)) throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS) + if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS) throw std::runtime_error("svdBatched: wrong shape of U array !"); - if(V.sizeAt(-2) != n) + if(V->sizeAt(-2) != n) throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V.sizeAt(-1) != (fullUV ? n : minDim)) + if(V->sizeAt(-1) != (fullUV ? n : minDim)) throw std::runtime_error("svdBatched: wrong shape of V array !"); - if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS) + if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS) throw std::runtime_error("svdBatched: wrong shape of V array !"); } - NDArray* pA = const_cast(&A); - NDArray* pS = &S; - NDArray* pU = &U; - NDArray* pV = &V; + NDArray* pA = const_cast(A); + NDArray* pS = S; + NDArray* pU = U; + NDArray* pV = V; std::vector toDelete; if(pA->ews() != 1 || pA->ordering() == 'c') { - pA = A.dup('f'); + pA = A->dup('f'); toDelete.push_back(pA); } - if(S.ews() != 1) { - pS = S.dup('f'); + if(S->ews() != 1) { + pS = S->dup('f'); toDelete.push_back(pS); } if(calcUV) { if(pU->ews() != 1 || pU->ordering() == 'c') { - pU = U.dup('f'); + pU = U->dup('f'); toDelete.push_back(pU); } if(pV->ews() != 1 || pV->ordering() == 'c') { - pV = V.dup('f'); + pV = V->dup('f'); toDelete.push_back(pV); } } @@ -532,10 +532,10 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& // query working space of SVD int lwork = 0; - if(A.dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS); - else if(A.dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS); + if(A->dataType() == DataType::DOUBLE) + status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); + else if(A->dataType() == DataType::FLOAT32) + status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS); else throw std::invalid_argument("svdBatched: given data type is unsupported !"); @@ -544,7 +544,7 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& // allocate memory dWork void* dWork = nullptr; - status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork); + status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork); if(status2 != cudaSuccess) throw cuda_exception::build("svdBatched: cuda failed !", status2); status2 = cudaDeviceSynchronize(); @@ -556,11 +556,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& NDArray::prepareSpecialUse({pS, pU, pV}, {pA}); // choose appropriate cuda gemm api depending on data types - if(A.dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + if(A->dataType() == DataType::DOUBLE) { + status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); } - else if(A.dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), reinterpret_cast(pU->getSpecialBuffer()), ldu, reinterpret_cast(pV->getSpecialBuffer()), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); + else if(A->dataType() == DataType::FLOAT32) { + status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams, bS); } else throw std::invalid_argument("svdBatched: given data type is unsupported !"); @@ -572,11 +572,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& NDArray::registerSpecialUse({pS, pU, pV}, {pA}); - S.assign(pS); + S->assign(pS); if(calcUV) { - U.assign(pU); - V.assign(pV); + U->assign(pU); + V->assign(pV); } for (int i = toDelete.size() - 1; i >= 0; --i) @@ -602,9 +602,11 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vectortranspose(); NDArray* V = outArrs[2]; + NDArray::prepareSpecialUse({S, U, V}, {x}); + if(x->rankOf() == 2) { - // svdQR(context, *x, *S, *U, VT, fullUV, calcUV); - svdJcb(context, *x, *S, *U, *V, fullUV, calcUV); + // svdQR(context, x, S, U, VT, fullUV, calcUV); + svdJcb(context, x, S, U, V, fullUV, calcUV); } else { @@ -621,7 +623,7 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vectorsize(); ++i) - svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV); + svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? 
tadsV->at(i) : nullptr, fullUV, calcUV); delete tadsX; delete tadsS; @@ -631,6 +633,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector -#include + +#ifndef SD_HAMMING_H +#define SD_HAMMING_H namespace nd4j { -namespace ops { -namespace helpers { + namespace ops { + namespace helpers { + void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output); + } + } +} - Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output); -} -} -} -#endif +#endif //DEV_TESTS_HAMMING_H diff --git a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp index baa08dad9..c840f6960 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp @@ -42,10 +42,11 @@ namespace helpers { Nd4jLong listDiffCount(nd4j::LaunchContext * context, NDArray* values, NDArray* keep) { auto xType = values->dataType(); - values->syncToHost(); - keep->syncToHost(); + NDArray::preparePrimaryUse({},{values, keep}); BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), LIBND4J_TYPES); + + NDArray::registerPrimaryUse({},{values, keep}); } BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, (NDArray* values, NDArray* keep);, LIBND4J_TYPES); @@ -97,16 +98,7 @@ namespace helpers { int listDiffFunctor(nd4j::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { auto xType = values->dataType(); - values->syncToHost(); - - if (keep != nullptr) - keep->syncToHost(); - - if (output1 != nullptr) - output1->syncToHost(); - - if (output2 != nullptr) - output2->syncToHost(); + NDArray::preparePrimaryUse({output1, output2}, {values, keep}); int result = 0; @@ -118,14 +110,7 @@ namespace helpers { throw std::runtime_error("ListDiff: Only integer and floating point data types are supported"); } - if (keep != nullptr) - keep->syncToDevice(); - - if (output1 != nullptr) - output1->syncToDevice(); - - if (output2 != nullptr) - output2->syncToDevice(); + NDArray::registerPrimaryUse({output1, output2}, {values, keep}); return result; } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 5d29ed826..b313acd9c 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -19,7 +19,6 @@ // #include -#include #include #include #include @@ -190,32 +189,6 @@ namespace nd4j { auto outSha = this->calculateOutputShape(&inSha, ctx); results = outSha->size(); - // we must "validate" our output shapes - /* - for (int e = 0; e < results; e++) { - auto ptr = outSha->at(e); - - // checking for the same pointer used twice - for (int i = 0; i < results; i++){ - if (i == e) - continue; - - auto com = outSha->at(i); - - if (ptr == com) - throw std::runtime_error("ShapeFunction returned same shape instance twice [" + *_descriptor->getOpName() + "]"); - } - - // checking for input pointer returned back - for (int i = 0; i < inSha.size(); i++){ - auto com = inSha.at(i); - - if (ptr == com) - throw std::runtime_error("ShapeFunction returned input shape instance as output [" + *_descriptor->getOpName() + "]"); - } - } - */ - // optionally saving shapeTime if (Environment::getInstance()->isProfiling() && node != nullptr) { shapeEnd = std::chrono::system_clock::now(); @@ -355,75 +328,145 @@ namespace nd4j { // rolling over inputs first int cnt = 0, inT = 0; 
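The DeclarableOp hunk that follows splits data-type validation into two symmetric paths: when block.isFastPath() is true the inputs and outputs come straight from block.fastpath_in() / block.fastpath_out(), otherwise the original walk over the VariableSpace is kept; the per-array rules themselves (checkInputMatch for inputs, and for outputs either "same" mode, "inherit" mode, or an explicit checkOutputMatch) are unchanged. A standalone, simplified model of the output rules, assuming nothing from libnd4j -- the real code consults the OpDescriptor rather than a boolean flag:

#include <vector>
#include <algorithm>

enum class DType { FLOAT32, DOUBLE, INT32, INT64 };

// Simplified model: in "same" mode an output must match its corresponding input
// (or input 0 when there are more outputs than inputs); in "inherit" mode it only
// has to match some input dtype. The actual checks live in OpDescriptor.
bool outputTypesValid(const std::vector<DType>& in, const std::vector<DType>& out,
                      bool sameMode) {
    for (size_t i = 0; i < out.size(); ++i) {
        if (sameMode) {
            if (in.empty()) continue;                 // mirrors the new empty-input guard
            DType ref = in[i < in.size() ? i : 0];
            if (out[i] != ref) return false;
        } else if (std::find(in.begin(), in.end(), out[i]) == in.end()) {
            return false;
        }
    }
    return true;
}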
std::vector inputTypes(block.width()); - for (auto &p: *(block.inputs())) { - auto var = block.variable(p); - - // we're not checking validity, if ANY types were explicitly allowed - //if (block.dataType(cnt) == nd4j::DataType::ANY) - // continue; - - // only validating non-null variables - if (var != nullptr && var->hasNDArray()) { - auto array = var->getNDArray(); - + if (block.isFastPath()) { + for (auto array: block.fastpath_in()) { inputTypes[inT++] = array->dataType(); if (!_descriptor->checkInputMatch(cnt, array->dataType())) { auto ctype = DataTypeUtils::asString(array->dataType()); - nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), cnt, ctype.c_str()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); return ND4J_STATUS_BAD_ARGUMENTS; } + cnt++; } + } else { + for (auto &p: *(block.inputs())) { + auto var = block.variable(p); - cnt++; - } - - // checking optionally available outputs - auto varSpace = block.getVariableSpace(); - for (int index = 0; index < DataTypeUtils::max(); index++) { - if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { - auto var = block.variable(block.nodeId(), index); + // we're not checking validity, if ANY types were explicitly allowed + //if (block.dataType(cnt) == nd4j::DataType::ANY) + // continue; // only validating non-null variables if (var != nullptr && var->hasNDArray()) { auto array = var->getNDArray(); - auto cType = array->dataType(); - if (_descriptor->isSameMode()) { - - if (index >= block.width()) { - auto iv = block.variable(0); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } else { - // for same mode, output type must be the same as input type - auto iv = block.variable(index); - - if (iv->getNDArray()->dataType() != cType) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - } - } else if (_descriptor->isInherit(index)) { - // in inherit mode, output type must be the same as one of input types - if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", _descriptor->getOpName()->data(), index, t.c_str()); - return ND4J_STATUS_BAD_ARGUMENTS; - } - - } else if (!_descriptor->checkOutputMatch(index, cType)) { - auto t = DataTypeUtils::asString(cType); - nd4j_printf("Op [%s] failed check for output [%i], DataType: [%i];\n", _descriptor->getOpName()->data(), index, t.c_str()); + inputTypes[inT++] = array->dataType(); + if (!_descriptor->checkInputMatch(cnt, array->dataType())) { + auto ctype = DataTypeUtils::asString(array->dataType()); + nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), cnt, ctype.c_str()); return ND4J_STATUS_BAD_ARGUMENTS; } } - } else - break; + + cnt++; + } + } + + if (block.isFastPath()) { + int index = 0; + for (auto array: block.fastpath_out()) { + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + + if (index >= block.width()) { + if (block.fastpath_in().size() == 0) + continue; + + auto ia = 
block.fastpath_in()[0]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } else { + // for same mode, output type must be the same as input type + auto ia = block.fastpath_in()[index]; + + if (ia->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + index++; + } + } else { + // checking optionally available outputs + auto varSpace = block.getVariableSpace(); + for (int index = 0; index < DataTypeUtils::max(); index++) { + if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) { + auto var = block.variable(block.nodeId(), index); + + // only validating non-null variables + if (var != nullptr && var->hasNDArray()) { + auto array = var->getNDArray(); + auto cType = array->dataType(); + + if (_descriptor->isSameMode()) { + + if (index >= block.width()) { + if (block.width() == 0) + continue; + + auto iv = block.variable(0); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } else { + // for same mode, output type must be the same as input type + auto iv = block.variable(index); + + if (iv->getNDArray()->dataType() != cType) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } + } else if (_descriptor->isInherit(index)) { + // in inherit mode, output type must be the same as one of input types + if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + + } else if (!_descriptor->checkOutputMatch(index, cType)) { + auto t = DataTypeUtils::asString(cType); + nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n", + _descriptor->getOpName()->data(), index, t.c_str()); + return ND4J_STATUS_BAD_ARGUMENTS; + } + } + } else + break; + } } diff --git a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp similarity index 67% rename from datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java 
rename to libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp index 8aeae58a5..607572b59 100644 --- a/datavec/datavec-spark/src/main/java/org/datavec/spark/functions/FlatMapFunctionAdapter.java +++ b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp @@ -14,16 +14,14 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.datavec.spark.functions; +// +// @author raver119@gmail.com +// +#include -import java.io.Serializable; - -/** - * - * A function that returns zero or more output records from each input record. - * - * Adapter for Spark interface in order to freeze interface changes between spark versions - */ -public interface FlatMapFunctionAdapter extends Serializable { - Iterable call(T t) throws Exception; +namespace nd4j { + BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) { + BroadcastIntOpsTuple t(scalar, pairwise, broadcast); + return t; + } } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 38122f985..e4fef2c3c 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -660,6 +660,98 @@ namespace simdOps { } }; + template + class IntOr { + public: + + op_def static X op(X d1, X d2) { + return d2 | d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class IntAnd { + public: + + op_def static X op(X d1, X d2) { + return d2 & d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class IntXor { + public: + + op_def static X op(X d1, X d2) { + return d2 ^ d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class ShiftLeft { + public: + + op_def static X op(X d1, X d2) { + return d1 << d2; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class ShiftRight { + public: + + op_def static X op(X d1, X d2) { + return d1 >> d2; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class CyclicShiftLeft { + public: + + op_def static X op(X d1, X d2) { + return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2); + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class CyclicShiftRight { + public: + + op_def static X op(X d1, X d2) { + return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2); + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template class Or { public: @@ -3746,7 +3838,7 @@ namespace simdOps { }; - template + template class IndexAbsoluteMax { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3799,7 +3891,7 @@ namespace simdOps { } }; - template + template class FirstIndex { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3861,7 +3953,7 @@ namespace simdOps { }; - template + template class LastIndex { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3920,7 +4012,7 @@ namespace simdOps { }; - template + template class IndexMax { public: @@ -3974,7 +4066,7 @@ namespace simdOps { }; - template + template class IndexAbsoluteMin { public: static _CUDA_HD inline functions::indexreduce::IndexValue op( @@ -4030,7 +4122,7 @@ namespace simdOps { 
}; - template + template class IndexMin { public: static _CUDA_HD inline functions::indexreduce::IndexValue op( diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu index 19f107ea4..027dcdd42 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -272,6 +272,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_12) { NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, nd4j::DataType::FLOAT32); nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); ASSERT_TRUE(c.equalsTo(&exp)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 1d4bf7338..82ed21709 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -910,7 +910,31 @@ TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) { auto *out = results->at(0); ASSERT_TRUE(exp.isSameShape(out)); - out->printBuffer("5HIST"); + // out->printBuffer("5HIST"); + ASSERT_TRUE(exp.equalsTo(out)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) { + + auto input = NDArrayFactory::create('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); + auto range = NDArrayFactory::create('c', {2}, {0, 1}); + auto bins = NDArrayFactory::create(5); + + auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); + + nd4j::ops::histogram_fixed_width op; + auto results = op.execute({&input, &range, &bins}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto out = results->at(0); + // out->printShapeInfo(); + // out->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(out)); ASSERT_TRUE(exp.equalsTo(out)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 8c484268e..87ac417be 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -623,12 +623,13 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) { TEST_F(DeclarableOpsTests13, shift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(32); e.assign(512); nd4j::ops::shift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -640,12 +641,13 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) { TEST_F(DeclarableOpsTests13, rshift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(512); e.assign(32); nd4j::ops::rshift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -657,12 +659,13 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) { TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(32); e.assign(512); nd4j::ops::cyclic_shift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -674,12 +677,107 @@ 
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(512); e.assign(32); nd4j::ops::cyclic_rshift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + nd4j::ops::rshift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::cyclic_shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + nd4j::ops::cyclic_rshift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} +TEST_F(DeclarableOpsTests13, shift_bits_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index df1421d71..65cb470a7 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -245,27 +245,47 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) { TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { auto x = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); - auto g = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); - auto b = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + auto g = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + auto b = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); nd4j::ops::layer_norm op; - auto result = op.execute({&x, &g, &b}, {}, {0}, {}); + auto result = op.execute({&x, &g, &b}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result->status()); delete result; } TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { auto x = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); - auto g = 
NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); - auto b = NDArrayFactory::create('c', {1, 5}, {1., 2., 3., 4., 5.}); + auto g = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + auto b = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); auto eps = NDArrayFactory::create('c', {1, 5}, {0., 0., 0., 0., 0.}); nd4j::ops::layer_norm_bp op; - auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {}); + auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false}); ASSERT_EQ(Status::OK(), result->status()); delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) { + + NDArray x('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32); + NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, nd4j::DataType::FLOAT32); + NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, nd4j::DataType::FLOAT32); + NDArray gradO('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32); + + NDArray gradI('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32); + NDArray gradG('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradB('c', {4}, nd4j::DataType::FLOAT32); + + x.linspace(-20, 0.5); + gradO.linspace(-4, 0.05); + + nd4j::ops::layer_norm_bp op; + auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true}); + ASSERT_EQ(Status::OK(), status); +} + TEST_F(DeclarableOpsTests15, test_hashCode_1) { auto x = NDArrayFactory::create('c', {10}); auto y = NDArrayFactory::create('c', {10}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index a23d5421e..adbff7f83 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -39,7 +39,7 @@ public: } }; -TEST_F(DeclarableOpsTests16, test_scatter_update_119) { +TEST_F(DeclarableOpsTests16, scatter_upd_1) { auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); auto y = NDArrayFactory::create(0); auto w = NDArrayFactory::create(3.0f); @@ -56,6 +56,27 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) { delete result; } +TEST_F(DeclarableOpsTests16, scatter_upd_2) { + + NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); + NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32); + NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); + NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32); + + x.linspace(1); + + nd4j::ops::scatter_upd op; + auto result = op.execute({&x, &indices, &updates}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + + TEST_F(DeclarableOpsTests16, test_size_dtype_1) { auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); auto z = NDArrayFactory::create(0.0f); @@ -66,4 +87,50 @@ TEST_F(DeclarableOpsTests16, test_size_dtype_1) { ASSERT_EQ(Status::OK(), status); ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_1) { + auto z = NDArrayFactory::empty(); + + nd4j::ops::noop op; + auto status = op.execute({}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_2) { + auto z = NDArrayFactory::empty(); + + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::noop op; + auto status = op.execute(&ctx); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, 
test_svd_1) { + auto x = NDArrayFactory::create('c', {3, 3}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f}); + auto z = NDArrayFactory::create('c', {3}); + + nd4j::ops::svd op; + auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { + auto x = NDArrayFactory::create({37, 37, 37}); + auto y = NDArrayFactory::create({8723, 8723, 8723}); + auto e = NDArrayFactory::create(18); + + nd4j::ops::bits_hamming_distance op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index d786a68cb..3b4ff6cd0 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -2757,10 +2757,18 @@ TEST_F(DeclarableOpsTests3, svd_test10) { /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test11) { - auto x = NDArrayFactory::create('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.}); - auto expS = NDArrayFactory::create('c', {3}); - auto expU = NDArrayFactory::create('c', {3,3}); - auto expV = NDArrayFactory::create('c', {3,3}); + NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, + 0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234, + 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695}); + NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571}); + NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, + 7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01, + 7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01, + 8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01}); + NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763, + -0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413, + -0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232, + -0.43596, 0.83108, -0.34531}); nd4j::ops::svd op; auto results = op.execute({&x}, {}, {0, 1, 16}); @@ -2775,6 +2783,10 @@ TEST_F(DeclarableOpsTests3, svd_test11) { ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + delete results; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index b596ebcd5..3af53bad0 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -639,10 +639,10 @@ TEST_F(DeclarableOpsTests5, eye_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, 
eye_test3) { - auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); + auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); nd4j::ops::eye op; - auto results = op.execute({}, {}, {-99, 3, 4, 2}); + auto results = op.execute({}, {9 /*int*/}, {-99, 3, 4, 2}); auto output = results->at(0); // output->printIndexedBuffer("Output eye"); @@ -656,10 +656,10 @@ TEST_F(DeclarableOpsTests5, eye_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test4) { - auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); + auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); nd4j::ops::eye op; - auto results = op.execute({}, {}, {-99, 3, 4, 2, 2}); + auto results = op.execute({}, {6/*double*/}, {-99, 3, 4, 2, 2}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -815,6 +815,23 @@ TEST_F(DeclarableOpsTests5, gatherNd_test7) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test8) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto e = NDArrayFactory::create('c', {2}, {1., 4.}); + + nd4j::ops::gather_nd op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { @@ -825,9 +842,13 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), results->status()); + auto output = results->at(0); - ASSERT_EQ(Status::OK(), results->status()); + exp.printIndexedBuffer("E"); + output->printIndexedBuffer("O"); + ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1069,6 +1090,23 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { + auto input = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.85910449, 0.18252879, 
0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 
0.29395619, 0.50477953, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto lengths = NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 
0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + + nd4j::ops::reverse_sequence op; + auto results = op.execute({&input, &lengths}, {}, {1, 0}); + ASSERT_EQ(Status::OK(), results->status()); + + auto z = results->at(0); + + ASSERT_EQ(e, *z); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Test_TopK_0) { auto x = NDArrayFactory::create('c', {2, 6}, {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); @@ -1882,6 +1920,50 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { delete result; } +/* @Test + public void testDynamicPartition(){ + INDArray data = Nd4j.createFromArray(2, 1, 2, 0); + INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); + INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) + .addIntegerArguments(3) //3 partitions + .addInputs(data, partitions).build()); + + INDArray exp0 = Nd4j.createFromArray(2, 0); + INDArray exp1 = Nd4j.createFromArray(2); + INDArray exp2 = Nd4j.createFromArray(1); + + assertEquals(exp0, out[0]); //Usually just gives [0,0] + assertEquals(exp1, out[1]); + assertEquals(exp2, out[2]); + }*/ +TEST_F(DeclarableOpsTests5, DynamicPartition_01) { + + auto x = NDArrayFactory::create({2,1,2,0}); + + auto y = NDArrayFactory::create({0,2,1,0}); + + int numPartition = 3; + std::vector exp( { NDArrayFactory::create('c', {2}, {2, 0}), + NDArrayFactory::create('c', {1}, {2}), + NDArrayFactory::create('c', {1}, {1})}); + + nd4j::ops::dynamic_partition op; + auto result = op.execute({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + ASSERT_EQ(result->size(), numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result->size(); e++) { + auto output = result->at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + 
ASSERT_TRUE(exp[e].equalsTo(output)); + } + + delete result; +} TEST_F(DeclarableOpsTests5, DynamicPartition_1) { @@ -1993,6 +2075,38 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) { delete result; } +TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::empty(); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + + auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); + auto d1 = NDArrayFactory::empty(); + auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + + nd4j::ops::dynamic_stitch op; + auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + +TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::create('c', {0}); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + + auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); + auto d1 = NDArrayFactory::create('c', {0, 5}); + auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + + nd4j::ops::dynamic_stitch op; + auto result = op.execute({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_1) { @@ -2275,8 +2389,8 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { - auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); - auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 62d297a50..a5e808867 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -1086,13 +1086,11 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) { auto y = NDArrayFactory::create({ 2, 1, 2}); -// ------------------------------------ - auto exp = NDArrayFactory::create({2, 2, 2}); nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); + auto res = op.execute({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_TRUE(exp.equalsTo(res->at(0))); @@ -1107,7 +1105,6 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) { auto y = NDArrayFactory::create({2, 1, 2}); -// ------------------------------------ auto exp = NDArrayFactory::create({2, 2, 2}); nd4j::ops::broadcast_dynamic_shape op; @@ -1122,17 +1119,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) { 
///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) { - auto x = NDArrayFactory::create( {2, 2, 2} ); + auto x = NDArrayFactory::create( {2, 2, 2} ); - auto y = NDArrayFactory::create({ 2, 1}); + auto y = NDArrayFactory::create({2, 1}); -// ------------------------------------ - - auto exp = NDArrayFactory::create({2, 2, 2}); + auto exp = NDArrayFactory::create({2, 2, 2}); nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); + auto res = op.execute({&x, &y}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, res->status()); ASSERT_TRUE(exp.equalsTo(res->at(0))); @@ -1145,9 +1140,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) { auto x = NDArrayFactory::create( {2, 1} ); - auto y = NDArrayFactory::create('c', {1}, { 4,}); - -// ------------------------------------ + auto y = NDArrayFactory::create('c', {1}, {4}); auto exp = NDArrayFactory::create({2, 4}); @@ -1161,49 +1154,7 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) { delete res; } -///////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) { - auto x = NDArrayFactory::create({2, 2, 2}); - - auto y = NDArrayFactory::create({2, 2}); - -// ------------------------------------ - - auto exp = NDArrayFactory::create({2, 2, 2}); - - nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); - - ASSERT_EQ(ND4J_STATUS_OK, res->status()); -// res->at(0)->printIndexedBuffer("Output"); -// exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(res->at(0))); - - delete res; -} - -///////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) { - - auto x = NDArrayFactory::create({2, 1, 2}); - - auto y = NDArrayFactory::create({2, 2, 4}); - -// ------------------------------------ - - auto exp = NDArrayFactory::create({2, 2, 4}); - - nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); - - ASSERT_EQ(ND4J_STATUS_OK, res->status()); - // res->at(0)->printIndexedBuffer("Output SGO 5"); -// exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(res->at(0))); - - delete res; -} ///////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) { @@ -1211,16 +1162,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) { auto y = NDArrayFactory::create({2, 2, 4}); -// ------------------------------------ - auto exp = NDArrayFactory::create({2, 2, 4}); nd4j::ops::broadcast_dynamic_shape op; auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); ASSERT_EQ(ND4J_STATUS_OK, res->status()); - // res->at(0)->printIndexedBuffer("Output SGO 6"); -// exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(res->at(0))); delete res; @@ -1233,16 +1180,12 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) { auto y = NDArrayFactory::create({2, 4, 1}); -// ------------------------------------ - auto exp = NDArrayFactory::create({2, 4, 3}); nd4j::ops::broadcast_dynamic_shape op; auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64); ASSERT_EQ(ND4J_STATUS_OK, res->status()); - // res->at(0)->printIndexedBuffer("Output SGO 7"); -// exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(res->at(0))); 
delete res; @@ -1255,19 +1198,15 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) { auto y = NDArrayFactory::create('c', {1}, {4}); -// ------------------------------------ + auto z = NDArrayFactory::create('c', {1}); auto exp = NDArrayFactory::create('c', {1}, {4}); nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res->status()); -// res->at(0)->printIndexedBuffer("Output SGO 8"); -// exp.printIndexedBuffer("Expect"); - ASSERT_TRUE(exp.equalsTo(res->at(0))); - - delete res; + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(z)); } ///////////////////////////////////////////////////////////////////////////////// @@ -1277,19 +1216,16 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) { auto y = NDArrayFactory::create('c', {1}, {1}); -// ------------------------------------ + auto z = NDArrayFactory::create('c', {2}); - auto exp = NDArrayFactory::create('c', {2}, {2,2}); + auto exp = NDArrayFactory::create('c', {2}, {2,2}); nd4j::ops::broadcast_dynamic_shape op; - auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); - ASSERT_EQ(ND4J_STATUS_OK, res->status()); - res->at(0)->printIndexedBuffer("Output SGO 9"); - exp.printIndexedBuffer("Expect9"); - ASSERT_TRUE(exp.equalsTo(res->at(0))); + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(exp.equalsTo(z)); - delete res; } //////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 2e1dab1a3..c80d75372 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -1423,6 +1423,61 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { delete result; } +TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { + auto x = NDArrayFactory::create('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + nd4j::ops::segment_mean op; + + auto result = op.execute({&x, &idx}, {}, {}); + ASSERT_EQ(result->status(), Status::OK()); + ASSERT_EQ(result->size(), 1); + exp.printIndexedBuffer("Expect Mean"); + result->at(0)->printIndexedBuffer("Output Mean"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result->at(0))); + + delete result; +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { + auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + nd4j::ops::segment_mean op; + x.linspace(1.); + auto result = op.execute({&x, &idx}, {}, {}); + ASSERT_EQ(result->status(), Status::OK()); + ASSERT_EQ(result->size(), 1); + exp.printIndexedBuffer("Expect Mean"); + result->at(0)->printIndexedBuffer("Output Mean"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result->at(0))); + + delete result; +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { + auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 
10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto z = NDArrayFactory::create('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + nd4j::ops::segment_mean op; + x.linspace(1.); + auto result = op.execute({&x, &idx}, {&z}, {}, {}, {}, false, nd4j::DataType::FLOAT32); + ASSERT_EQ(result, Status::OK()); + + exp.printIndexedBuffer("Expect Mean"); + z.printIndexedBuffer("Output Mean"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(z)); + +// delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index b9445cc70..baba901bf 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -207,6 +207,46 @@ TEST_F(EmptyTests, Test_dup_1) { delete dup; } +TEST_F(EmptyTests, test_empty_scatter_1) { + auto x = NDArrayFactory::create('c', {5}); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); + + x.linspace(1.0f); + + nd4j::ops::scatter_upd op; + auto result = op.execute({&x, &indices, &updates}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(x, *z); + + delete result; +} + +TEST_F(EmptyTests, test_empty_scatter_2) { + auto x = NDArrayFactory::create('c', {5}); + auto z = NDArrayFactory::create('c', {5}); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); + + x.linspace(1.0f); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo()); + ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + bool args[] = {true}; + ctx.setBArguments(args, 1); + + nd4j::ops::scatter_upd op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(x, z); +} + TEST_F(EmptyTests, test_shaped_empty_1) { auto empty = NDArrayFactory::create('c', {2, 0, 3}); std::vector shape = {2, 0, 3}; diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index ef3710371..d0d67000b 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -400,6 +400,32 @@ TEST_F(JavaInteropTests, Test_Synonyms_3) { ASSERT_EQ(nameRef, name); } +TEST_F(JavaInteropTests, Test_FastPath_Validation_1) { + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_FastPath_Validation_2) 
{ + auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + /* TEST_F(JavaInteropTests, test_avgpooling_edge_1) { int inOutH = 35; diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 7b9e788f7..4ab884d28 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -2261,304 +2261,4 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) { ASSERT_TRUE(x->isEmpty()); delete x; -} - -// printCudaGlobal<<<1,1,0,*stream>>>(dX, 6); -// printCudaGlobal<<<1,1,0,*stream>>>(dXShapeInfo, 8); -// printCudaGlobal<<<1,1,0,*stream>>>(dZ, 2); -// printCudaGlobal<<<1,1,0,*stream>>>(dZShapeInfo, 6); -// printCudaGlobal<<<1,1,0,*stream>>>(dimension, 1); -// printCudaGlobal<<<1,1,0,*stream>>>(tadShapeInfo, 6); -// printCudaGlobal<<<1,1,0,*stream>>>(tadOffsets, 2); -// cudaStreamSynchronize(*stream); - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) { - - auto x = NDArrayFactory::create('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); - x.syncToHost(); - auto z = NDArrayFactory::create('c', {5, 8}); - z.syncToHost(); - - std::vector buffers(4); - std::vector shapes(4); - std::vector hostShapes(4); - - for (size_t i = 0; i < buffers.size(); i++) { - buffers[i] = x.specialBuffer(); - shapes[i] = x.specialShapeInfo(); - hostShapes[i] = x.shapeInfo(); - } - Nd4jPointer extra[2]; - extra[1] = x.getContext()->getCudaStream(); - ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) { - - auto x = NDArrayFactory::create('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); - auto z = NDArrayFactory::create('f', {5, 8}); - - std::vector buffers(4); - std::vector shapes(4); - std::vector hostShapes(4); - - x.syncToHost(); - z.syncToHost(); - - for (size_t i = 0; i < buffers.size(); i++) { - buffers[i] = x.specialBuffer(); - shapes[i] = x.specialShapeInfo(); - hostShapes[i] = x.shapeInfo(); - } - - Nd4jPointer extra[2]; - extra[1] = x.getContext()->getCudaStream(); - - ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) { - - auto x = NDArrayFactory::create('c', {2,3}, {1,2,3,4,5,6}); - auto y = NDArrayFactory::create('c', {1,3}, {7,8,9}); - auto z = NDArrayFactory::create('f', {3, 3}); - - - std::vector buffers(2); - std::vector shapes(2); - std::vector hostShapes(2); - - x.syncToHost(); - y.syncToHost(); - z.syncToHost(); - - buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo(); - buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo(); - - Nd4jPointer extra[2]; - extra[1] = x.getContext()->getCudaStream(); - - ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), 
nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) { - - auto x = NDArrayFactory::create('c', {2,3}, {1,2,3,4,5,6}); - auto y = NDArrayFactory::create('c', {1,3}, {7,8,9}); - auto z = NDArrayFactory::create('c', {3, 3}); - - x.syncToHost(); - y.syncToHost(); - z.syncToHost(); - - std::vector buffers(2); - std::vector shapes(2); - std::vector hostShapes(2); - - buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo(); - buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo(); - - Nd4jPointer extra[2]; - extra[1] = x.getContext()->getCudaStream(); - - ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) { - - auto x = NDArrayFactory::create('c', {1,2,3}, {1,2,3,4,5,6}); - auto y = NDArrayFactory::create('c', {1,2,3}, {7,8,9,10,11, 12}); - - auto z = NDArrayFactory::create('c', {2, 2, 3}); - auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); - std::vector buffers(2); - std::vector shapes(2); - std::vector hostShapes(2); - - buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo(); - buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo(); - - Nd4jPointer extra[2]; - extra[1] = x.getContext()->getCudaStream(); - - ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) { - - auto x1 = NDArrayFactory::create('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12}); - auto x2 = NDArrayFactory::create('c', {1,2,3}, {13,14,15,16,17, 18}); - auto x3 = NDArrayFactory::create('c', {1,2,3}, {19,20,21,22,23, 24}); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - - auto z = NDArrayFactory::create('c', {4, 2, 3}); - - std::vector buffers(3); - std::vector shapes(3); - std::vector hostShapes(3); - - buffers[0] = x1.specialBuffer(); shapes[0] = x1.specialShapeInfo(); hostShapes[0] = x1.shapeInfo(); - buffers[1] = x2.specialBuffer(); shapes[1] = x2.specialShapeInfo(); hostShapes[1] = x2.shapeInfo(); - buffers[2] = x3.specialBuffer(); shapes[2] = x3.specialShapeInfo(); hostShapes[2] = x3.shapeInfo(); - - Nd4jPointer extra[2]; - extra[1] = x1.getContext()->getCudaStream(); - - ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) { - - auto x1 = NDArrayFactory::create(1); - auto x2 = NDArrayFactory::create(2); - auto x3 = NDArrayFactory::create(3); - - auto z = NDArrayFactory::create('c', {3}, {1,2,3}); - - x1.syncToHost(); - x2.syncToHost(); - x3.syncToHost(); - - std::vector buffers(3); - std::vector shapes(3); - std::vector hostShapes(3); - - buffers[0] = x1.specialBuffer(); shapes[0] = x1.specialShapeInfo(); hostShapes[0] = x1.shapeInfo(); - buffers[1] = x2.specialBuffer(); shapes[1] = x2.specialShapeInfo(); hostShapes[1] = x2.shapeInfo(); - buffers[2] = x3.specialBuffer(); shapes[2] = x3.specialShapeInfo(); 
hostShapes[2] = x3.shapeInfo(); - - Nd4jPointer extra[2]; - extra[1] = x1.getContext()->getCudaStream(); - - ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) { - - auto totalCount = 1000; - auto width = 300; - std::vector lx(totalCount); - for (int i = 0; i < totalCount; i++) { - lx[i] = NDArrayFactory::create('c', {1, width}); - lx[i].assign(i); - lx[i].syncToHost(); - } - - auto z = NDArrayFactory::create('c', {totalCount, width}); - - std::vector buffers(totalCount); - std::vector shapes(totalCount); - std::vector hostShapes(totalCount); - - for (size_t i = 0; i < lx.size(); i++) { - buffers[i] = lx[i].specialBuffer(); - shapes[i] = lx[i].specialShapeInfo(); - hostShapes[i] = lx[i].shapeInfo(); - } - - Nd4jPointer extra[2]; - extra[1] = nd4j::LaunchContext::defaultContext()->getCudaStream(); - - ::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); -} - -TEST_F(NDArrayCudaBasicsTests, TestTear_1) { - auto input = NDArrayFactory::create('c', {1, 10, 10}); - std::vector arrays; // = {NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10})}; - int total = 151; - for (int e = 0; e < total; e++) { - input.assign(e); - arrays.emplace_back(input); - } - auto z = NDArrayFactory::create('c', {total, 10, 10}); - - Nd4jPointer extra[1]; - extra[1] = input.getContext()->getCudaStream(); - - std::vector buffers(total); - std::vector shapes(total); - std::vector hostShapes(total); - - for (size_t i = 0; i < buffers.size(); i++) { - buffers[i] = arrays[i].specialBuffer(); - shapes[i] = arrays[i].specialShapeInfo(); - hostShapes[i] = arrays[i].shapeInfo(); - } - - ::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); - nd4j::ops::tear op; - - auto result = op.execute({&z}, {}, {1, 2}); - //ASSERT_EQ(10, result->size()); - auto e = result->size() - 1; - //for (size_t e = 0; e < result->size(); e++) { -// arrays[e].printIndexedBuffer("Input list at 40"); -// result->at(e)->printIndexedBuffer("OUtput TEAR at 40"); - //} -// ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); - - delete result; -// delete tads; -} - -TEST_F(NDArrayCudaBasicsTests, TestTear_2) { - - auto input = NDArrayFactory::create('c', {1, 10, 10}); - - std::vector arrays; // = {NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10}), NDArrayFactory::create('c', {1, 10, 10})}; - for (int e = 0; e < 10; e++) { - input.assign(e); - arrays.emplace_back(input); - arrays[e].syncToHost(); - } - - auto z = NDArrayFactory::create('c', {10, 10, 10}); - - Nd4jPointer extra[2]; - extra[1] = input.getContext()->getCudaStream(); - - std::vector buffers(10); - std::vector shapes(10); - std::vector hostShapes(10); - - for (size_t i = 0; i < buffers.size(); i++) { - buffers[i] = arrays[i].specialBuffer(); - shapes[i] = arrays[i].specialShapeInfo(); - hostShapes[i] = 
arrays[i].shapeInfo(); - } - - std::vector dimsToExclude({1,2}); - - - ::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); - - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimsToExclude); - //std::vector arraysData(arrays.size()); - Nd4jPointer* arraysData; - cudaError_t err = cudaMalloc(&arraysData, arrays.size() * sizeof(void*)); - if (err != 0) { - printf("Cannot allocate device memory for targets due error %d\n", err); - ASSERT_TRUE(false); - } - for (size_t i = 0; i < arrays.size(); i++) { - Nd4jPointer target = arrays[i].specialBuffer(); - cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice); - } - ::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets()); -// auto result = op.execute({&z}, {}, {1, 2}); - - //ASSERT_EQ(10, result->size()); - err = cudaFree(arraysData); - if (err != 0) { - printf("Cannot deallocate device memory for targets due error %d\n", err); - ASSERT_TRUE(false); - } - -// ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e))); - -// delete result; -// delete tads; -} +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index fe190d9bb..9aac42ddf 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -992,81 +992,6 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) { ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); } -TEST_F(NativeOpsTests, FlattenTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {2, 5,5}); - auto z = NDArrayFactory::create('c', {2, 5,5}); - - Nd4jPointer extra[6]; -#ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; -#endif - x.linspace(1.0,2); - y.linspace(2,2); - - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - auto dimension = NDArrayFactory::create({0, 1}); - auto dimensions = reinterpret_cast(dimension.buffer()); - auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - exp(1, {0}).linspace(1,2); - ::flatten(extra, - 25, 'c', z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - x.buffer(), x.shapeInfo(), - x.specialBuffer(), x.specialShapeInfo()); - -// exp.printIndexedBuffer("Exp"); -// z.printIndexedBuffer("Flatten"); - ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(NativeOpsTests, ConcatTest_1) { - auto x = NDArrayFactory::create('c', {5, 5}); - auto y = NDArrayFactory::create('c', {5, 5}); - auto exp = NDArrayFactory::create('c', {10,5}); - auto z = NDArrayFactory::create('c', {10,5}); - - Nd4jPointer extra[6]; -#ifdef __CUDABLAS__ - extra[1] = x.getContext()->getCudaStream(); - extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; - x.syncToHost(); - y.syncToHost(); - printf("Unsupported for CUDA platform yet.\n"); - return; -#endif - x.linspace(1.0); - 
y.linspace(26); - - //y.assign(2.); - x.syncToDevice(); - z.syncToDevice(); - int d = 0; - auto dimension = NDArrayFactory::create('c', {1}, {d}); - auto dimensions = reinterpret_cast(dimension.buffer()); - //auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); - auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); - exp.linspace(1); - Nd4jPointer datas[] = {x.buffer(), y.buffer()}; - Nd4jPointer shapes[] = {x.shapeInfo(), y.shapeInfo()}; - - ::concat(extra, - 0, 2, datas, shapes, nullptr, nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), - nullptr, nullptr); - -// exp.printIndexedBuffer("Exp"); -// z.printIndexedBuffer("Concat"); - ASSERT_TRUE(exp.equalsTo(z)); -} - TEST_F(NativeOpsTests, ConcatTest_2) { auto x = NDArrayFactory::create('c', {5, 5}); auto y = NDArrayFactory::create('c', {5, 5}); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 2d9a23e59..b76538afd 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,8 @@ #include #include +#include + using namespace nd4j; using namespace nd4j::graph; @@ -55,3 +58,26 @@ public: } }; +/* +TEST_F(PlaygroundTests, test_relubp_1) { + auto x = NDArrayFactory::create('c', {128, 64, 224, 224}); + auto y = x.ulike(); + auto z = x.ulike(); + RandomGenerator rng(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &x, -1.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0); + + int iterations = 10; + + auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < iterations; e++) + ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + auto time = (Nd4jLong) outerTime / iterations; + auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024; + + nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw); +} +*/ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 71bbd26ee..eb3424007 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -530,7 +530,7 @@ public abstract class DifferentialFunction { public SDVariable arg(int num){ SDVariable[] args = args(); Preconditions.checkNotNull(args, "Arguments are null for function %s", this.getOwnName()); - Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s)", args.length); + Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", args.length, num); return args[num]; } @@ -547,8 +547,11 @@ public abstract class DifferentialFunction { /** * Resolve properties and arguments right before execution of * this operation. + * + * @deprecated Will be removed in the future. 
Ops should support array arguments. Should not be used or overridden. */ - public void resolvePropertiesFromSameDiffBeforeExecution() { + @Deprecated + public final void resolvePropertiesFromSameDiffBeforeExecution() { val properties = sameDiff.propertiesToResolveForFunction(this); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); val currentFields = this.propertiesForFunction(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 3bf1754db..7ffeca762 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -133,23 +133,8 @@ import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; -import org.nd4j.linalg.api.ops.impl.scalar.LogX; +import org.nd4j.linalg.api.ops.impl.scalar.*; import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; -import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; -import org.nd4j.linalg.api.ops.impl.scalar.Relu6; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction; -import org.nd4j.linalg.api.ops.impl.scalar.Step; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; @@ -163,7 +148,6 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.Broadcast; import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; import org.nd4j.linalg.api.ops.impl.shape.Cross; @@ -790,20 +774,20 @@ public class DifferentialFunctionFactory { return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable(); } - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, bias, dimensions).outputVariable(); + public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int...
dimensions) { + return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable(); } - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions).outputVariables(); + public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) { + return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables(); } - public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, dimensions).outputVariable(); + public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { + return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable(); } - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, gradient, dimensions).outputVariables(); + public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) { + return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables(); } public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) { @@ -1235,19 +1219,19 @@ public class DifferentialFunctionFactory { return new Xor(sameDiff(), ix, iy).outputVariable(); } - public SDVariable shift(SDVariable ix, int shift) { + public SDVariable shift(SDVariable ix, SDVariable shift) { return new ShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rshift(SDVariable ix, int shift) { + public SDVariable rshift(SDVariable ix, SDVariable shift) { return new RShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rotl(SDVariable ix, int shift) { + public SDVariable rotl(SDVariable ix, SDVariable shift) { return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rotr(SDVariable ix, int shift) { + public SDVariable rotr(SDVariable ix, SDVariable shift) { return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); } @@ -1341,6 +1325,10 @@ public class DifferentialFunctionFactory { return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); } + public SDVariable reluDerivative(SDVariable input, SDVariable grad){ + return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); + } + public SDVariable relu6(SDVariable iX, double cutoff) { return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); } @@ -1450,14 +1438,6 @@ public class DifferentialFunctionFactory { return new MatrixInverse(sameDiff(), in, false).outputVariable(); } - public SDVariable broadcast(SDVariable iX, int... shape) { - return broadcast(iX, ArrayUtil.toLongArray(shape)); - } - - public SDVariable broadcast(SDVariable iX, long... 
shape) { - return new Broadcast(sameDiff(), iX, shape).outputVariable(); - } - public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java index 1156de102..34b305001 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/ListenerVariables.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.listeners; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import java.util.Arrays; import java.util.HashSet; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java index 1932a6c75..7dbb0119d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java @@ -1,7 +1,7 @@ package org.nd4j.autodiff.listeners.checkpoint; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index 36334b648..b063e18a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -16,10 +16,10 @@ package org.nd4j.autodiff.listeners.records; -import com.google.common.base.Predicates; -import com.google.common.collect.Collections2; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.base.Predicates; +import org.nd4j.shade.guava.collect.Collections2; +import org.nd4j.shade.guava.collect.ImmutableMap; +import org.nd4j.shade.guava.collect.Lists; import java.util.ArrayList; import java.util.Collection; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 430b4d83a..0d2700b43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -111,9 +111,6 @@ public class SDVariable implements Serializable { return variableType == VariableType.CONSTANT; } - - - /** * Allocate and return a new array * based on the vertex id and weight initialization. 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index e6f30d12e..e09ceda75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -16,11 +16,11 @@ package org.nd4j.autodiff.samediff; -import com.google.common.base.Predicates; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Maps; -import com.google.common.collect.Table; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.base.Predicates; +import org.nd4j.shade.guava.collect.HashBasedTable; +import org.nd4j.shade.guava.collect.Maps; +import org.nd4j.shade.guava.collect.Table; +import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import lombok.*; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 1bde174c1..7da89aa36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.ops; -import com.google.common.collect.Sets; +import org.nd4j.shade.guava.collect.Sets; import java.util.HashMap; import java.util.Map; import java.util.Set; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 70eaa5cd9..10fc0b44a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2428,7 +2428,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitShift(String name, SDVariable x, int shift) { + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { validateInteger("shift_bits", x); SDVariable result = f().shift(x, shift); return updateVariableNameAndReference(result, name); @@ -2441,7 +2441,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitShiftRight(String name, SDVariable x, int shift) { + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { validateInteger("rshift_bits", x); SDVariable result = f().rshift(x, shift); return updateVariableNameAndReference(result, name); @@ -2454,7 +2454,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitRotl(String name, SDVariable x, int shift) { + public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { validateInteger("cyclic_shift_bits", x); SDVariable result = f().rotl(x, shift); return updateVariableNameAndReference(result, name); @@ -2467,7 +2467,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ 
- public SDVariable bitRotr(String name, SDVariable x, int shift) { + public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { validateInteger("cyclic_rshift_bits", x); SDVariable result = f().rotr(x, shift); return updateVariableNameAndReference(result, name); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 928bf3e6e..eb89a0f3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -759,8 +759,8 @@ public class SDNN extends SDOps { * * @return Output variable */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) { - return layerNorm(null, input, gain, bias, dimensions); + public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { + return layerNorm(null, input, gain, bias, channelsFirst, dimensions); } /** @@ -772,13 +772,15 @@ public class SDNN extends SDOps { * @param input Input variable * @param gain gain * @param bias bias + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs * @return Output variable */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) { + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { validateFloatingPoint("layerNorm", "input", input); validateFloatingPoint("layerNorm", "gain", gain); validateFloatingPoint("layerNorm", "bias", bias); - SDVariable result = f().layerNorm(input, gain, bias, dimensions); + SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions); return updateVariableNameAndReference(result, name); } @@ -789,8 +791,8 @@ public class SDNN extends SDOps { * * @return Output variable */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, int... dimensions) { - return layerNorm((String)null, input, gain, dimensions); + public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { + return layerNorm((String)null, input, gain, channelsFirst, dimensions); } /** @@ -803,10 +805,10 @@ public class SDNN extends SDOps { * @param gain gain * @return Output variable */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, int... dimensions) { + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... 
dimensions) { validateFloatingPoint("layerNorm", "input", input); validateFloatingPoint("layerNorm", "gain", gain); - SDVariable result = f().layerNorm(input, gain, dimensions); + SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions); return updateVariableNameAndReference(result, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java index ef1bfefb7..6faf29bfc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/FlatBuffersMapper.java @@ -16,7 +16,7 @@ package org.nd4j.autodiff.samediff.serde; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import java.nio.ByteOrder; import java.util.*; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index e9ad61c04..5bc175952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -16,8 +16,8 @@ package org.nd4j.autodiff.validation; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; +import org.nd4j.shade.guava.collect.ImmutableSet; +import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -341,8 +341,8 @@ public class OpValidation { //Finally: check execution/output - Map outOrig = original.execAll(tc.placeholderValues()); - Map outDe = deserialized.execAll(tc.placeholderValues()); + Map outOrig = original.outputAll(tc.placeholderValues()); + Map outDe = deserialized.outputAll(tc.placeholderValues()); Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model"); for(String s : outOrig.keySet()){ @@ -560,7 +560,7 @@ public class OpValidation { ImmutableSet info; try { //Dependency note: this ClassPath class was added in Guava 14 - info = com.google.common.reflect.ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader()) + info = org.nd4j.shade.guava.reflect.ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader()) .getTopLevelClassesRecursive("org.nd4j.linalg.api.ops"); } catch (IOException e) { //Should never happen diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java index e07b5c9b8..fd08e4270 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java @@ -42,6 +42,7 @@ import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import 
org.nd4j.shade.jackson.databind.SerializationFeature; +import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.databind.module.SimpleModule; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; @@ -80,7 +81,8 @@ public abstract class BaseEvaluation implements IEvalu .withFieldVisibility(JsonAutoDetect.Visibility.ANY) .withGetterVisibility(JsonAutoDetect.Visibility.NONE) .withSetterVisibility(JsonAutoDetect.Visibility.NONE) - .withCreatorVisibility(JsonAutoDetect.Visibility.NONE)); + .withCreatorVisibility(JsonAutoDetect.Visibility.ANY) + ); return ret; } @@ -107,15 +109,15 @@ public abstract class BaseEvaluation implements IEvalu public static T fromJson(String json, Class clazz) { try { return objectMapper.readValue(json, clazz); - } catch (IllegalArgumentException e) { - if (e.getMessage().contains("Invalid type id")) { + } catch (InvalidTypeIdException e) { + if (e.getMessage().contains("Could not resolve type id")) { try { return (T) attempFromLegacyFromJson(json, e); } catch (Throwable t) { throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", t); } } - throw e; + throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } @@ -129,7 +131,7 @@ public abstract class BaseEvaluation implements IEvalu * @param json JSON to attempt to deserialize * @param originalException Original exception to be re-thrown if it isn't legacy JSON */ - protected static T attempFromLegacyFromJson(String json, IllegalArgumentException originalException) { + protected static T attempFromLegacyFromJson(String json, InvalidTypeIdException originalException) throws InvalidTypeIdException { if (json.contains("org.deeplearning4j.eval.Evaluation")) { String newJson = json.replaceAll("org.deeplearning4j.eval.Evaluation", "org.nd4j.evaluation.classification.Evaluation"); return (T) fromJson(newJson, Evaluation.class); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java index dafe571f0..01fd322e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java @@ -16,8 +16,8 @@ package org.nd4j.evaluation.classification; -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; +import org.nd4j.shade.guava.collect.HashMultiset; +import org.nd4j.shade.guava.collect.Multiset; import lombok.Getter; import java.io.Serializable; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java index fed8c63b4..26ef8bba9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/CustomEvaluation.java @@ -16,7 +16,7 @@ package org.nd4j.evaluation.custom; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import java.io.Serializable; import java.util.ArrayList; import java.util.List; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java index cbad73da3..079755055 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java @@ -16,7 +16,7 @@ package org.nd4j.evaluation.custom; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java index 6acc7fc4e..a8f5c32e9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java @@ -16,7 +16,7 @@ package org.nd4j.evaluation.serde; -import com.google.common.collect.Multiset; +import org.nd4j.shade.guava.collect.Multiset; import org.nd4j.evaluation.classification.ConfusionMatrix; import org.nd4j.shade.jackson.core.JsonGenerator; import org.nd4j.shade.jackson.core.JsonProcessingException; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index a92066d7a..82bfdc843 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -64,6 +64,7 @@ public class DifferentialFunctionClassHolder { add("outputVariables"); add("tArguments"); add("iArguments"); + add("bArguments"); add("hash"); add("opName"); add("sameDiff"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3f270e342..da580b748 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -98,7 +98,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class, org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class, org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class, - org.nd4j.linalg.api.ops.impl.layers.Linear.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm.class, @@ -229,6 +228,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scalar.Pow.class, org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, + org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.Relu6.class, 
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class, org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class, @@ -267,7 +267,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scatter.ScatterSub.class, org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate.class, org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent.class, - org.nd4j.linalg.api.ops.impl.shape.Broadcast.class, org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape.class, org.nd4j.linalg.api.ops.impl.shape.Concat.class, org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java index 719ac792d..7a651fb88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java @@ -18,9 +18,9 @@ package org.nd4j.imports.graphmapper.onnx; import org.nd4j.shade.protobuf.ByteString; import org.nd4j.shade.protobuf.Message; -import com.google.common.primitives.Floats; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Floats; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index f57fef4c7..3ad3267c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -17,8 +17,8 @@ package org.nd4j.imports.graphmapper.tf; import org.nd4j.shade.protobuf.Message; -import com.google.common.primitives.Floats; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Floats; +import org.nd4j.shade.guava.primitives.Ints; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.functions.DifferentialFunction; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java index 36313593d..eebee472c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.scalar.Step; +import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,8 +41,7 @@ public class ActivationReLU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = 
Nd4j.getExecutioner().exec(new Step(in)); - dLdz.muli(epsilon); + INDArray dLdz = Nd4j.exec(new RectifiedLinearDerivative(in, epsilon, in.ulike()))[0]; return new Pair<>(dLdz, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 121ca2b43..46dd786c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -17,8 +17,8 @@ package org.nd4j.linalg.api.ndarray; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import com.google.flatbuffers.FlatBufferBuilder; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; @@ -142,21 +142,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } - /** - * Returns true if this array is compressed, and false otherwise - * @return - */ @Override public boolean isCompressed() { return compressed; } - /** - * This method marks INDArray instance as compressed - * PLEASE NOTE: Do not use this method unless you 100% have to - * - * @param reallyCompressed - */ @Override public void markAsCompressed(boolean reallyCompressed) { this.compressed = reallyCompressed; @@ -949,17 +939,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public int elementWiseStride() { - /* - if(Shape.elementWiseStride(shapeInfo()) < 0 && !attemptedToFindElementWiseStride) { - INDArray reshapeAttempt = Shape.newShapeNoCopy(this,new int[]{1,length()}, ordering() == 'f'); - if(reshapeAttempt != null && reshapeAttempt.elementWiseStride() > 0) { - Shape.setElementWiseStride(shapeInfo(), reshapeAttempt.stride(-1)); - this.shapeInformation = Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), offset(),reshapeAttempt.stride(-1), ordering()); - } - attemptedToFindElementWiseStride = true; - - } - */ return Shape.elementWiseStride(shapeInfoDataBuffer()); } @@ -1028,19 +1007,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return toTad; } - /** - * Get the vector along a particular dimension - * - * @param index the index of the vector to getScalar - * @param dimension the dimension to getScalar the vector from - * @return the vector along a particular dimension - */ - @Override - @Deprecated - public INDArray javaTensorAlongDimension(int index, int... 
dimension) { - return doTad(index, dimension); - } - private void setShapeInformation(Pair shapeInfo) { this.shapeInformation = shapeInfo.getFirst(); this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); @@ -1131,14 +1097,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return ret2.permutei(finalPermuteDims); } - - - /** - * Returns the number of possible vectors for a given dimension - * - * @param dimension the dimension to calculate the number of vectors for - * @return the number of possible vectors along a dimension - */ @Override public long vectorsAlongDimension(int dimension) { if (dimension == 0 && isVector() || isRowVectorOrScalar()) @@ -1171,17 +1129,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return length / size(dimension); } - /** - * Get the vector along a particular dimension - * - * @param index the index of the vector to get - * @param dimension the dimension to get the vector from - * @return the vector along a particular dimension - */ @Override public INDArray vectorAlongDimension(int index, int dimension) { - if (dimension < 0) + if (dimension < 0) { dimension = jvmShapeInfo.getRank() + dimension; + } //return the whole thing if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 @@ -1189,12 +1141,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; } - INDArray ret = tensorAlongDimension(index, dimension); - //if (isMatrix() && ret.isVector() && dimension == 1 && !ret.isRowVector()) - // return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); - //else if (isMatrix() && ret.isVector() && dimension == 0 && !ret.isColumnVector()) - // return ret.reshape(ArrayUtil.reverseCopy(ret.shape())); - return ret; + return tensorAlongDimension(index, dimension); } @Override @@ -1217,13 +1164,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); } - - /** - * Cumulative sum along a dimension - * - * @param dimension the dimension to perform cumulative sum along - * @return the cumulative sum along the specified dimension - */ @Override public INDArray cumsumi(int dimension) { validateNumericalArray("cumsumi", true); @@ -1372,25 +1312,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return logEntropy(Integer.MAX_VALUE).getDouble(0); } - /** - * Cumulative sum along a dimension (in place) - * - * @param dimension the dimension to perform cumulative sum along - * @return the cumulative sum along the specified dimension - */ @Override public INDArray cumsum(int dimension) { validateNumericalArray("cumsum", true); return dup().cumsumi(dimension); } - /** - * Assign all of the elements in the given - * ndarray to this ndarray - * - * @param arr the elements to assign - * @return this - */ @Override public INDArray assign(final INDArray arr) { Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), @@ -1399,7 +1326,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal"); - //Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set(this, arr, this, length())); Nd4j.getExecutioner().exec(new 
org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); return this; } @@ -1434,7 +1360,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray putScalar(long i, float value) { return putScalar(i, (double) value); - } @Override @@ -1561,7 +1486,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; } - @Override public INDArray putScalar(int[] indexes, float value) { return putScalar(indexes, (double) value); @@ -1577,27 +1501,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return putScalar(indexes, (double) value); } - /** - * Returns an ndarray with 1 if the element is epsilon equals - * - * @param other the number to compare - * @return a copied ndarray with the given - * binary conditions - */ @Override public INDArray eps(Number other) { validateNumericalArray("eps", true); return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); } - /** - * epsilon equals than comparison: - * If the given number is less than the - * comparison number the item is 0 otherwise 1 - * - * @param other the number to compare - * @return - */ @Override public INDArray eps(INDArray other) { validateNumericalArray("eps", true); @@ -1634,7 +1543,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); } - @Override public INDArray lt(INDArray other) { validateNumericalArray("less than (lt)", false); @@ -1696,9 +1604,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan())); } - /** - * Negate each element. - */ @Override public INDArray neg() { validateNumericalArray("negative (neg)", true); @@ -1707,9 +1612,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); } - /** - * Negate each element (in-place). 
- */ @Override public INDArray negi() { validateNumericalArray("negative (negi)", true); @@ -1809,16 +1711,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - - - /** - * Returns the element at the specified row/column - * This will throw an exception if the - * - * @param row the row of the element to return - * @param column the row of the element to return - * @return a scalar indarray of the element at this index - */ @Override public INDArray getScalar(long row, long column) { return getScalar(new long[] {row, column}); @@ -1840,6 +1732,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { if(isEmpty()) return this; + Nd4j.getCompressor().autoDecompress(this); + // fixme: eventually it would be nice to have this in native code if (isS()) { val list = new ArrayList(); @@ -1849,8 +1743,9 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.create(list, this.shape(), this.ordering()); } - Nd4j.getCompressor().autoDecompress(this); - return Shape.toOffsetZeroCopy(this, order); + val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); + z.assign(this); + return z; } /** @@ -1983,27 +1878,20 @@ public abstract class BaseNDArray implements INDArray, Iterable { } - /** - * Inserts the element at the specified index - * - * @param indices the indices to insert into - * @param element a scalar ndarray - * @return a scalar ndarray of the element at this index - */ @Override public INDArray put(int[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); if (!element.isScalar()) throw new IllegalArgumentException("Unable to insert anything but a scalar"); if (isRowVector() && indices[0] == 0 && indices.length == 2) { - int ix = 0; //Shape.offset(javaShapeInformation); + int ix = 0; for (int i = 1; i < indices.length; i++) ix += indices[i] * stride(i); if (ix >= data.length()) throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } else { - int ix = 0; //Shape.offset(javaShapeInformation); + int ix = 0; for (int i = 0; i < indices.length; i++) if (size(i) != 1) ix += indices[i] * stride(i); @@ -2011,10 +1899,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); data.put(ix, element.getDouble(0)); } - - return this; - } @Override @@ -2068,39 +1953,16 @@ public abstract class BaseNDArray implements INDArray, Iterable { return putWhereWithMask(mask,Nd4j.scalar(put)); } - /** - * Inserts the element at the specified index - * - * @param i the row insert into - * @param j the column to insert into - * @param element a scalar ndarray - * @return a scalar ndarray of the element at this index - */ @Override public INDArray put(int i, int j, INDArray element) { return put(new int[] {i, j}, element); } - /** - * Inserts the element at the specified index - * - * @param i the row insert into - * @param j the column to insert into - * @param element a scalar ndarray - * @return a scalar ndarray of the element at this index - */ @Override public INDArray put(int i, int j, Number element) { return putScalar(new int[] {i, j}, element.doubleValue()); } - /** - * Assigns the given matrix (put) to the specified slice - * - * @param slice the slice to assign - * @param put the slice to put - * @return this for chainability - */ @Override public INDArray putSlice(int slice, INDArray put) { Nd4j.getCompressor().autoDecompress(this); @@ 
-2200,10 +2062,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getStrides(shape, ordering); } - - /** - * Returns the square of the Euclidean distance. - */ @Override public double squaredDistance(INDArray other) { validateNumericalArray("squaredDistance", false); @@ -2211,9 +2069,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return d2 * d2; } - /** - * Returns the (euclidean) distance. - */ @Override public double distance2(INDArray other) { validateNumericalArray("distance2", false); @@ -2221,9 +2076,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue(); } - /** - * Returns the (1-norm) distance. - */ @Override public double distance1(INDArray other) { validateNumericalArray("distance1", false); @@ -2231,8 +2083,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, other)).getFinalResult().doubleValue(); } - - @Override public INDArray get(INDArray indices) { if(indices.rank() > 2) { @@ -2288,49 +2138,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } - @Override - public INDArray get(List> indices) { - INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()]; - for(int i = 0; i < indArrayIndices.length; i++) { - indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i))); - } - - boolean hasNext = true; - Generator>> iterate = SpecifiedIndex.iterate(indArrayIndices); - List resultList = new ArrayList<>(); - while(hasNext) { - try { - List> next = iterate.next(); - int[][] nextArr = new int[next.size()][]; - for(int i = 0; i < next.size(); i++) { - nextArr[i] = Ints.toArray(next.get(i)); - } - - int[] curr = Ints.concat(nextArr); - INDArray currSlice = this; - for(int j = 0; j < curr.length; j++) { - currSlice = currSlice.slice(curr[j]); - } - - //slice drops the first dimension, this adds a 1 to match normal numpy behavior - currSlice = currSlice.reshape(Longs.concat(new long[]{1},currSlice.shape())); - - resultList.add(currSlice); - - - } - catch(NoSuchElementException e) { - hasNext = false; - } - } - - - - - return Nd4j.concat(0,resultList.toArray(new INDArray[resultList.size()])); - } - - @Override public INDArray put(INDArray indices, INDArray element) { if(indices.rank() > 2) { @@ -2343,7 +2150,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next())); } - } else { List arrList = new ArrayList<>(); @@ -2356,8 +2162,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice})); arrList.add(slice(row.getInt(j))); } - - } } else if(indices.isRowVector()) { @@ -2365,15 +2169,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { arrList.add(slice(indices.getInt(i))); } } - } - - return this; - } - @Override public INDArray put(INDArrayIndex[] indices, INDArray element) { Nd4j.getCompressor().autoDecompress(this); @@ -2441,7 +2240,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; } - /** * Mainly here for people coming from numpy. 
* This is equivalent to a call to permute @@ -2544,7 +2342,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return ret; } - protected void init(int[] shape, int[] stride) { //null character if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { @@ -2564,7 +2361,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } - @Override public INDArray getScalar(long i) { if (i >= this.length()) @@ -3034,13 +2830,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return dup().rsubiRowVector(rowVector); } - /** - * Inserts the element at the specified index - * - * @param i the index insert into - * @param element a scalar ndarray - * @return a scalar ndarray of the element at this index - */ @Override public INDArray put(int i, INDArray element) { Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element); @@ -3804,12 +3593,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return normmax(false, dimension); } - /** - * Reverse division - * - * @param other the matrix to divide from - * @return - */ @Override public INDArray rdiv(INDArray other) { validateNumericalArray("rdiv", false); @@ -3820,37 +3603,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * Reverse divsion (in place) - * - * @param other - * @return - */ @Override public INDArray rdivi(INDArray other) { return rdivi(other, this); } - /** - * Reverse division - * - * @param other the matrix to subtract from - * @param result the result ndarray - * @return - */ @Override public INDArray rdiv(INDArray other, INDArray result) { validateNumericalArray("rdiv", false); return dup().rdivi(other, result); } - /** - * Reverse division (in-place) - * - * @param other the other ndarray to subtract - * @param result the result ndarray - * @return the ndarray with the operation applied - */ @Override public INDArray rdivi(INDArray other, INDArray result) { validateNumericalArray("rdivi", false); @@ -3859,23 +3622,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * Reverse subtraction - * - * @param other the matrix to subtract from - * @param result the result ndarray - * @return - */ @Override public INDArray rsub(INDArray other, INDArray result) { validateNumericalArray("rsub", false); return rsubi(other, result); } - /** - * @param other - * @return - */ @Override public INDArray rsub(INDArray other) { validateNumericalArray("rsub", false); @@ -3886,22 +3638,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * @param other - * @return - */ @Override public INDArray rsubi(INDArray other) { return rsubi(other, this); } - /** - * Reverse subtraction (in-place) - * - * @param other the other ndarray to subtract - * @param result the result ndarray - * @return the ndarray with the operation applied - */ @Override public INDArray rsubi(INDArray other, INDArray result) { validateNumericalArray("rsubi", false); @@ -3910,12 +3651,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return result; } - /** - * Set the value of the ndarray to the specified value - * - * @param value the value to assign - * @return the ndarray with the values - */ @Override public INDArray assign(Number value) { Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " + @@ -3924,34 
+3659,17 @@ public abstract class BaseNDArray implements INDArray, Iterable { return this; } - @Override public INDArray assign(boolean value) { return assign(value ? 1 : 0); } - - /** - * Assign all elements from given ndarray that are matching given condition, - * ndarray to this ndarray - * - * @param arr the elements to assign - * @param condition - * @return this - */ @Override public INDArray assignIf(INDArray arr, Condition condition) { BooleanIndexing.assignIf(this, arr, condition); return this; } - /** - * Replaces all elements in this ndarray that are matching give condition, with corresponding elements from given array - * - * @param arr - * @param condition - * @return - */ @Override public INDArray replaceWhere(INDArray arr, Condition condition) { Nd4j.getCompressor().autoDecompress(this); @@ -3960,7 +3678,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { } @Override - @Deprecated + @Deprecated //TODO: Investigate. Not deprecated in the base interface. public long linearIndex(long i) { long idx = i; for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { @@ -4187,15 +3905,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return addi(n, this); } - - - /** - * Replicate and tile array to fill out to the given shape - * See: - * https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358 - * @param shape the new shape of this ndarray - * @return the shape to fill out to - */ @Override public INDArray repmat(int[] shape) { Nd4j.getCompressor().autoDecompress(this); @@ -4223,16 +3932,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return out; } - - /** - * Insert a row in to this array - * Will throw an exception if this - * ndarray is not a matrix - * - * @param row the row insert into - * @param toPut the row to insert - * @return this - */ @Override public INDArray putRow(long row, INDArray toPut) { if (isRowVector() && toPut.isVector()) { @@ -4241,15 +3940,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); } - /** - * Insert a column in to this array - * Will throw an exception if this - * ndarray is not a matrix - * - * @param column the column to insert - * @param toPut the array to put - * @return this - */ @Override public INDArray putColumn(int column, INDArray toPut) { Nd4j.getCompressor().autoDecompress(this); @@ -4866,11 +4556,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return reshape(length()); } - /** - * Flattens the array for linear indexing - * - * @return the flattened version of this array - */ @Override public void sliceVectors(List list) { if (isVector()) @@ -4918,12 +4603,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return col.reshape(col.length(), 1); } - - /** - * Get whole rows from the passed indices. - * - * @param rindices - */ @Override public INDArray getRows(int[] rindices) { Nd4j.getCompressor().autoDecompress(this); @@ -4940,13 +4619,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { } } - /** - * Returns a subset of this array based on the specified - * indexes - * - * @param indexes the indexes in to the array - * @return a view of the array with the specified indices - */ @Override public INDArray get(INDArrayIndex... 
indexes) { Nd4j.getCompressor().autoDecompress(this); @@ -5134,13 +4806,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return out; } - - /** - * Get whole columns - * from the passed indices. - * - * @param cindices - */ @Override public INDArray getColumns(int... cindices) { if (!isMatrix() && !isVector()) @@ -5349,12 +5014,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { return jvmShapeInfo.shape; } - /** - * Returns the shape information debugging - * information - * - * @return the shape information debugging information - */ @Override public String shapeInfoToString() { return Shape.shapeToString(this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java index 3570ed7ad..6a112b868 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArray.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.api.ndarray; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import net.ericaro.neoitertools.Generator; @@ -153,7 +153,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { arrList.set(j,put); } } - } } else if(indices.isRowVector()) { @@ -161,12 +160,8 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { arrList.add(slice(indices.getInt(i))); } } - return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()])); - } - - } @Override @@ -259,21 +254,13 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - @Override - public INDArray get(List> indices) { - return null; - } - @Override public INDArray put(INDArray indices, INDArray element) { INDArrayIndex[] realIndices = new INDArrayIndex[indices.rank()]; for(int i = 0; i < realIndices.length; i++) { realIndices[i] = new SpecifiedIndex(indices.slice(i).dup().data().asInt()); } - - return put(realIndices,element); - } @Override @@ -328,13 +315,11 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return sparseInformation; } - @Override public LongBuffer shapeInfo() { return null; } - @Override public boolean isCompressed() { return false; @@ -364,7 +349,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return Shape.sparseOffsets(sparseInformation); } - @Override public int stride(int dimension) { int rank = Shape.rank(shapeInformation); @@ -414,11 +398,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - @Override - public INDArray javaTensorAlongDimension(int index, int... 
dimension) { - return null; - } - @Override public INDArray cumsumi(int dimension) { return null; @@ -479,7 +458,6 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { return null; } - @Override public INDArray isInfinite() { throw new UnsupportedOperationException(); @@ -554,6 +532,7 @@ public abstract class BaseSparseNDArray implements ISparseNDArray { public INDArray lte(Number other) { return null; } + @Override public INDArray lt(INDArray other) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index 1e85be0cd..da5cd3f60 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -16,9 +16,9 @@ package org.nd4j.linalg.api.ndarray; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import com.google.flatbuffers.FlatBufferBuilder; import net.ericaro.neoitertools.Generator; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -568,13 +568,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return this; } - /** - * Returns a subset of this array based on the specified - * indexes - * - * @param indexes the indexes in to the array - * @return a view of the array with the specified indices - */ @Override public INDArray get(INDArrayIndex... indexes) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java index bd0f7c905..92e59486c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ndarray; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.*; @@ -124,14 +124,6 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray { return this; } - - /** - * Returns a subset of this array based on the specified - * indexes - * - * @param indexes the indexes in to the array - * @return a view of the array with the specified indices - */ @Override public INDArray get(INDArrayIndex... 
indexes) { //check for row/column vector and point index being 0 diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 727b5db6d..49cabf6bb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -41,45 +41,44 @@ import org.nd4j.linalg.string.NDArrayStrings; */ public interface INDArray extends Serializable, AutoCloseable { /** - * Returns the shape information debugging - * information - * @return the shape information debugging information + * Returns the shape information debugging information + * @return the shape information. */ String shapeInfoToString(); /** * Shape info - * @return + * @return Shape info */ DataBuffer shapeInfoDataBuffer(); /** * Sparse info - * @return + * @return Sparse info. */ DataBuffer sparseInfoDataBuffer(); /** * Shape info - * @return + * @return Shape info */ LongBuffer shapeInfo(); /** - * Returns true if this array is a view or not - * @return + * Check if this array is a view or not. + * @return true if array is a view. */ boolean isView(); /** - * Returns true if this array is sparse - * @return + * Check if this array is sparse + * @return true if this array is sparse. */ boolean isSparse(); /** - * Returns true if this array is compressed, and false otherwise - * @return + * Check if this array is compressed. + * @return true if this array is compressed. */ boolean isCompressed(); @@ -87,11 +86,10 @@ public interface INDArray extends Serializable, AutoCloseable { * This method marks INDArray instance as compressed * PLEASE NOTE: Do not use this method unless you 100% have to * - * @param reallyCompressed + * @param reallyCompressed new value for compressed. */ void markAsCompressed(boolean reallyCompressed); - /** * Returns the rank of the ndarray (the number of dimensions). * @@ -108,27 +106,31 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Element wise stride + * @return the element wise stride */ int elementWiseStride(); /** - * Get a scalar - * at the given linear offset + * Get a double at the given linear offset unsafe, without checks. * @param offset the offset to get at - * @return this + * @return double value at offset */ - double getDoubleUnsafe(long offset); + double getDoubleUnsafe(long offset); //TODO: consider deleting. + /** + * Get string value at given index. + * @param index index to retrieve + * @return string value at index. + */ String getString(long index); /** - * Insert a scalar - * at the given linear offset + * Insert a scalar at the given linear offset * @param offset the offset to insert at * @param value the value to insert * @return this */ - INDArray putScalarUnsafe(long offset, double value); + INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting. /** * Returns the number of possible vectors for a given dimension @@ -164,17 +166,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray tensorAlongDimension(long index, int...
dimension); - /** - * Get the vector along a particular dimension - * - * @param index the index of the vector to getScalar - * @param dimension the dimension to getScalar the vector from - * @return the vector along a particular dimension - */ - @Deprecated - INDArray javaTensorAlongDimension(int index, int... dimension); - - /** * Returns the cumulative sum along a dimension. In-place method. * @@ -192,8 +183,7 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray cumsum(int dimension); /** - * Assign all of the elements in the given - * ndarray to this ndarray + * Assign all of the elements in the given ndarray to this ndarray * * @param arr the elements to assign * @return this @@ -256,9 +246,19 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray putScalar(int[] i, double value); - + /** + * See {@link #putScalar(int[], double)} + */ INDArray putScalar(long[] i, double value); + + /** + * See {@link #putScalar(int[], double)} + */ INDArray putScalar(long[] i, float value); + + /** + * See {@link #putScalar(int[], double)} + */ INDArray putScalar(long[] i, int value); /** @@ -302,7 +302,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray lt(Number other); - /** * Put the specified float value at the specified indices in this array * @@ -329,8 +328,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray eps(Number other); - - /** * Returns the binary ndarray for "Equals" comparison. * @@ -369,10 +366,8 @@ public interface INDArray extends Serializable, AutoCloseable { * @param other the ndarray to compare. * @return the binary ndarray for "Less" comparison. */ - INDArray lt(INDArray other); - /** * Returns the binary ndarray for "Epsilon equals" comparison. * @@ -405,7 +400,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray eq(INDArray other); - /** * Returns the binary ndarray for "Greater Than" comparison. 
* @@ -426,8 +420,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray isNaN(); - - /** * Returns the ndarray negative (cloned) * @@ -466,7 +458,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rsub(Number n); - /** * Reverse subtraction in place - i.e., (n - thisArrayValues) * @@ -475,7 +466,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rsubi(Number n); - /** * Division by a number * @@ -492,7 +482,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray divi(Number n); - /** * Scalar multiplication (copy) * @@ -509,7 +498,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray muli(Number n); - /** * Scalar subtraction (copied) * @@ -518,7 +506,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray sub(Number n); - /** * In place scalar subtraction * @@ -543,7 +530,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addi(Number n); - /** * Reverse division (number / ndarray) * @@ -553,7 +539,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rdiv(Number n, INDArray result); - /** * Reverse in place division * @@ -581,11 +566,12 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rsubi(Number n, INDArray result); - /** - * @param n - * @param result - * @return + * Division of ndarray by number + * + * @param n the number to divide by + * @param result the result ndarray + * @return the result ndarray */ INDArray div(Number n, INDArray result); @@ -594,24 +580,35 @@ public interface INDArray extends Serializable, AutoCloseable { * * @param n the number to divide by * @param result the result ndarray - * @return + * @return the result ndarray */ INDArray divi(Number n, INDArray result); - + /** + * Multiplication of ndarray. + * + * @param n the number to multiply by + * @param result the result ndarray + * @return the result ndarray + */ INDArray mul(Number n, INDArray result); - /** * In place multiplication of this ndarray * * @param n the number to divide by * @param result the result ndarray - * @return + * @return the result ndarray */ INDArray muli(Number n, INDArray result); - + /** + * Subtraction of this ndarray + * + * @param n the number to subtract by + * @param result the result ndarray + * @return the result ndarray + */ INDArray sub(Number n, INDArray result); /** @@ -623,6 +620,12 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray subi(Number n, INDArray result); + /** + * Addition of this ndarray. + * @param n the number to add + * @param result the result ndarray + * @return the result ndarray + */ INDArray add(Number n, INDArray result); /** @@ -634,25 +637,24 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addi(Number n, INDArray result); - /** - * Returns a subset of this array based on the specified - * indexes + * Returns a subset of this array based on the specified indexes * * @param indexes the indexes in to the array * @return a view of the array with the specified indices */ INDArray get(INDArrayIndex... indexes); + //TODO: revisit after #8166 is resolved. /** - * Return a mask on whether each element - * matches the given condition + * Return a mask on whether each element matches the given condition * @param comp * @param condition * @return */ INDArray match(INDArray comp,Condition condition); + //TODO: revisit after #8166 is resolved.
/** * Returns a mask * @param comp @@ -681,54 +683,51 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray getWhere(Number comp,Condition condition); + //TODO: unused / untested method. (only used to forward calls from putWhere(Number,INDArray ,Condition). /** - * Assign the element according - * to the comparison array + * Assign the element according to the comparison array * @param comp the comparison array * @param put the elements to put * @param condition the condition for masking on - * @return + * @return a copy of this array with the conditional assignments. */ INDArray putWhere(INDArray comp,INDArray put,Condition condition); - + //TODO: unused / untested method. /** - * Assign the element according - * to the comparison array + * Assign the element according to the comparison array * @param comp the comparison array * @param put the elements to put * @param condition the condition for masking on - * @return + * @return a copy of this array with the conditional assignments. */ INDArray putWhere(Number comp,INDArray put,Condition condition); - + //TODO: unused / untested method. (only used to forward calls from other putWhereWithMask implementations. /** - * Use a pre computed mask - * for assigning arrays + * Use a pre computed mask for assigning arrays * @param mask the mask to use * @param put the array to put - * @return the resulting array + * @return a copy of this array with the conditional assignments. */ INDArray putWhereWithMask(INDArray mask,INDArray put); - + //TODO: unused / untested method. /** - * Use a pre computed mask - * for assigning arrays + * Use a pre computed mask for assigning arrays * @param mask the mask to use * @param put the array to put - * @return the resulting array + * @return a copy of this array with the conditional assignments. */ INDArray putWhereWithMask(INDArray mask,Number put); + //TODO: unused / untested method. /** - * Assign the element according - * to the comparison array + * Assign the element according to the comparison array * @param comp the comparison array * @param put the elements to put * @param condition the condition for masking on - * @return + * @return a copy of this array with the conditional assignments. */ INDArray putWhere(Number comp,Number put,Condition condition); @@ -739,14 +738,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray get(INDArray indices); - /** - * Get the elements from this ndarray based on the specified indices - * @param indices an ndaray of the indices to get the elements for - * @return the elements to get the array for - */ - @Deprecated - INDArray get(List> indices); - /** * Get an INDArray comprised of the specified columns only. Copy operation. * @@ -779,20 +770,20 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rdivi(INDArray other); - + //TODO: unused / untested method. 
/** * Reverse division * - * @param other the matrix to subtract from + * @param other the matrix to divide from * @param result the result ndarray - * @return + * @return the result ndarray */ INDArray rdiv(INDArray other, INDArray result); /** * Reverse division (in-place) * - * @param other the other ndarray to subtract + * @param other the matrix to divide from * @param result the result ndarray * @return the ndarray with the operation applied */ @@ -803,11 +794,10 @@ public interface INDArray extends Serializable, AutoCloseable { * * @param other the matrix to subtract from * @param result the result ndarray - * @return + * @return the result ndarray */ INDArray rsub(INDArray other, INDArray result); - /** * Element-wise reverse subtraction (copy op). i.e., other - this * @@ -850,22 +840,19 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray assign(boolean value); /** - * Get the linear index of the data in to - * the array + * Get the linear index of the data in to the array * * @param i the index to getScalar * @return the linear index in to the data */ long linearIndex(long i); - + //TODO: unused / untested method. only used recursively. /** - * - * @param list + * Flattens the array for linear indexing in list. */ void sliceVectors(List list); - /** * Assigns the given matrix (put) to the specified slice * @@ -883,16 +870,15 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray cond(Condition condition); - /** * Replicate and tile array to fill out to the given shape - * + * See: + * https://github.com/numpy/numpy/blob/master/numpy/matlib.py#L310-L358 * @param shape the new shape of this ndarray * @return the shape to fill out to */ INDArray repmat(int... shape); - /** * Repeat elements along a specified dimension. * @@ -902,11 +888,9 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray repeat(int dimension, long... repeats); - /** * Insert a row in to this array - * Will throw an exception if this - * ndarray is not a matrix + * Will throw an exception if this ndarray is not a matrix * * @param row the row insert into * @param toPut the row to insert @@ -916,8 +900,7 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Insert a column in to this array - * Will throw an exception if this - * ndarray is not a matrix + * Will throw an exception if this ndarray is not a matrix * * @param column the column to insert * @param toPut the array to put @@ -927,7 +910,6 @@ public interface INDArray extends Serializable, AutoCloseable { /** * Returns the element at the specified row/column - * This will throw an exception if the * * @param row the row of the element to return * @param column the row of the element to return @@ -958,26 +940,20 @@ public interface INDArray extends Serializable, AutoCloseable { */ double distance1(INDArray other); - /** * Put element in to the indices denoted by * the indices ndarray. - * This is equivalent to: + * In numpy this is equivalent to: * a[indices] = element * - * in numpy. 
- * * @param indices the indices to put * @param element the element array to put * @return this array */ INDArray put(INDArray indices,INDArray element); - - /** - * Put the elements of the ndarray - * in to the specified indices + * Put the elements of the ndarray in to the specified indices * * @param indices the indices to put the ndarray in to * @param element the ndarray to put @@ -986,8 +962,7 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray put(INDArrayIndex[] indices, INDArray element); /** - * Put the elements of the ndarray - * in to the specified indices + * Put the elements of the ndarray in to the specified indices * * @param indices the indices to put the ndarray in to * @param element the ndarray to put @@ -1015,7 +990,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray put(int i, int j, INDArray element); - /** * Inserts the element at the specified index * @@ -1026,7 +1000,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray put(int i, int j, Number element); - /** * Inserts the element at the specified index * @@ -1036,7 +1009,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray put(int i, INDArray element); - /** * In place division of a column vector * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java deleted file mode 100644 index 0826ac5f9..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseModule.java +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.ArrayList; -import java.util.List; - -/** - * Abstract base class for {@link Module} - * that handles Dynamic ops and handles nesting. - * - * This is a logical unit for defining layers - * very similar to pytorch's modules, or tensorflow's layers. 
- * - * @author Adam Gibson - */ -@NoArgsConstructor -public abstract class BaseModule extends DynamicCustomOp implements Module { - private List modules = new ArrayList<>(); - - public BaseModule(String opName, INDArray[] inputs, INDArray[] outputs, List tArguments, List iArguments, List modules) { - super(opName, inputs, outputs, tArguments, iArguments); - this.modules = modules; - } - - public BaseModule(String opName, SameDiff sameDiff, SDVariable[] args, boolean inPlace, List modules) { - super(opName, sameDiff, args, inPlace); - this.modules = modules; - } - - @Override - public Module[] subModules() { - return modules.toArray(new Module[modules.size()]); - } - - @Override - public void addModule(Module module) { - modules.add(module); - } - - - - -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 7fc0679db..10c26d29e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index f52450eee..d2190098c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -16,9 +16,9 @@ package org.nd4j.linalg.api.ops; -import com.google.common.collect.Lists; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.collect.Lists; +import org.nd4j.shade.guava.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Longs; import lombok.*; import lombok.extern.slf4j.Slf4j; import onnx.Onnx; @@ -35,6 +35,7 @@ import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; +import java.lang.reflect.Array; import java.util.*; /** @@ -611,6 +612,21 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { return in == null ? null : new INDArray[]{in}; } + protected static T[] wrapFilterNull(T... 
in){ + int count = 0; + for( int i=0; i calculateOutputShape(){ if(inputArguments != null && !inputArguments.isEmpty()){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java index 8e352be13..9ae09a57d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.aggregates; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 4c0fb2e6b..3f56096a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; +import java.util.List; import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -25,6 +26,7 @@ import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -89,4 +91,9 @@ public abstract class BaseCompatOp extends DynamicCustomOp { public Map> attributeAdaptersForFunction() { return super.attributeAdaptersForFunction(); } + + @Override + public List calculateOutputShape() { + throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops."); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java index 85a94eb13..52705ce9e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java @@ -68,15 +68,6 @@ public class Enter extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().getArr().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java index f9e358f3c..a7fecf03c 100644 --- 
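The wrapFilterNull helper introduced above copies only the non-null entries of a varargs array into a correctly typed result; a minimal sketch of that behaviour (assuming the java.lang.reflect.Array import added in this hunk, not the verbatim implementation):

    @SafeVarargs
    protected static <T> T[] wrapFilterNull(T... in) {
        int count = 0;
        for (int i = 0; i < in.length; i++) {
            if (in[i] != null)
                count++;
        }
        // Allocate an array of the same component type, sized to the non-null entries
        @SuppressWarnings("unchecked")
        T[] out = (T[]) Array.newInstance(in.getClass().getComponentType(), count);
        int j = 0;
        for (int i = 0; i < in.length; i++) {
            if (in[i] != null)
                out[j++] = in[i];
        }
        return out;
    }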
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Exit.java @@ -56,15 +56,6 @@ public class Exit extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(arg().getShapeDescriptor()); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java index a3ace4f13..4f5d11b38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/LoopCond.java @@ -37,15 +37,6 @@ public class LoopCond extends BaseCompatOp { return "loop_cond"; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 386f4a075..6cf52e7b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -64,15 +64,6 @@ public class Merge extends BaseCompatOp { return 60L; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(arg().getShapeDescriptor()); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java index fabd0479b..367a134a3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/NextIteration.java @@ -53,15 +53,6 @@ public class NextIteration extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(arg().getArr() != null) { - return Collections.singletonList(LongShapeDescriptor.fromShape(arg().getShape(), arg().dataType())); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 77145a625..e94a7bc54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.Getter; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -63,18 +63,6 @@ public class Switch extends BaseCompatOp { return OP_NAME; } - @Override - public List calculateOutputShape() { - if(args()[0].getArr() != null) { - val arg0 = args()[0]; - val arr0 = arg0.getArr(); - val dtype = arr0.dataType(); - return Arrays.asList(LongShapeDescriptor.fromShape(arg0.getShape(), dtype),LongShapeDescriptor.fromShape(arg0.getShape(), dtype)); - } - else - return Collections.emptyList(); - } - @Override public SDVariable[] outputVariables() { return super.outputVariables(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java deleted file mode 100644 index 27f357b4b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java +++ /dev/null @@ -1,198 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.layers; - -import lombok.Builder; -import lombok.NoArgsConstructor; -import lombok.val; -import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.samediff.VariableType; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.blas.params.MMulTranspose; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseModule; -import org.nd4j.linalg.api.ops.Module; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.weightinit.WeightInitScheme; -import org.nd4j.weightinit.impl.ZeroInitScheme; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -/** - * Linear: - * a * bT - * - * @author Adam Gibson - */ -@NoArgsConstructor -public class Linear extends BaseModule { - private DifferentialFunction forward; - private int nIn,nOut; - private WeightInitScheme weightInitScheme,biasWeightInitScheme; - - @Builder(builderMethodName = "execBuilder") - public Linear(int nIn, - int nOut, - WeightInitScheme weightInitScheme, - WeightInitScheme biasWeightInitScheme) { - super(null, - getParams(nIn,nOut,weightInitScheme,biasWeightInitScheme), - new INDArray[]{}, - new ArrayList(), new ArrayList(), new ArrayList()); - this.weightInitScheme = weightInitScheme; - this.biasWeightInitScheme = biasWeightInitScheme; - this.nIn = nIn; - this.nOut = nOut; - } - - @Builder(builderMethodName = "sameDiffBuilder") - public Linear(SameDiff sameDiff, - int nIn, - int nOut, - WeightInitScheme weightInitScheme, - WeightInitScheme biasWeightInitScheme) { - super(null, sameDiff, null, false, new ArrayList()); - this.weightInitScheme = weightInitScheme; - this.biasWeightInitScheme = biasWeightInitScheme; - - this.nIn = nIn; - this.nOut = nOut; - - } - - @Override - public String opName() { - return "linear"; - } - - @Override - public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - - } - - @Override - public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - - } - - @Override - public List doDiff(List f1) { - execSameDiff(); - return forward.doDiff(f1); - } - - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),new long[]{nOut,nIn}), inputArguments()[1].dataType())); - - ret.add(LongShapeDescriptor.fromShape(Shape.getMatrixMultiplyShape(inputArguments()[0].shape(),inputArguments()[1].transpose().shape()), inputArguments()[1].dataType())); - if(biasWeightInitScheme != null) { - ret.add(LongShapeDescriptor.fromShape(new long[]{nOut,1}, inputArguments()[1].dataType())); - } - return ret; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - - 
@Override - public void exec(INDArray... inputs) { - val inputArguments = inputArguments(); - if(inputArguments == null || inputArguments.length < 1) { - throw new IllegalStateException("No arguments found."); - } - - INDArray weights = inputArguments[0]; - INDArray right = inputArguments[1]; - - val outputArguments = outputArguments(); - - if(outputArguments == null || outputArguments.length < 1) { - if(inputArguments.length == 1) - addOutputArgument(inputs[0].mmul(weights.transpose())); - else - addOutputArgument(inputs[0].mmul(weights.transpose()).addiColumnVector(right)); - - } - else { - inputs[0].mmul(weights.transpose(),outputArguments[0]); - } - - } - - @Override - public void execSameDiff(SDVariable... input) { - val args = args(); - if(args == null || args.length == 0) { - throw new IllegalStateException("No arguments found"); - } - - if(forward == null) { - //bias needs to be added yet - if(args.length > 1) { - /* - forward = f().add(new Mmul(sameDiff, input[0],args()[0], - MMulTranspose.builder() - .transposeA(false) - .transposeB(true) - .build()).outputVariables()[0],args()[1]); - */ - } else { - forward = new Mmul(sameDiff, input[0],args()[0], - MMulTranspose.builder().transposeA(false).transposeB(true).build()); - } - - this.outputVariables = forward.outputVariables(); - } - - - } - - private static INDArray[] getParams(int nIn, - int nOut, - WeightInitScheme paramsScheme, - WeightInitScheme biasInitScheme) { - if(biasInitScheme != null) { - return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn}),biasInitScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,1})}; - } - else { - return new INDArray[] {paramsScheme.create(Nd4j.defaultFloatingPointType(), new long[]{nOut,nIn})}; - - } - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index d3fe330fd..810974103 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -268,14 +268,6 @@ public class Conv3D extends DynamicCustomOp { return ret; } - - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if (numIArguments() < 1) { - addArgs(); - } - } - @Override public boolean isConfigProperties() { return true; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index cfef2a61b..21941de93 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -38,11 +38,6 @@ public class L2Loss extends DynamicCustomOp { super(sameDiff, new SDVariable[]{var}); } - @Override - public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], Nd4j.defaultFloatingPointType())); - } - @Override public String opName() { return "l2_loss"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java index 4ad8fc5d9..cdddee2f6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/HashCode.java @@ -47,11 +47,6 @@ public class HashCode extends DynamicCustomOp { this.outputArguments.add(result); } - @Override - public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG)); - } - @Override public String opName() { return "hashcode"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index 3de44537a..d0c1bae38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.api.ops.impl.reduce; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.NoArgsConstructor; import lombok.val; import onnx.Onnx; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 15673a5a0..98fa587b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -77,8 +78,12 @@ public class RectifiedLinear extends BaseScalarOp { @Override public List doDiff(List i_v) { - SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; - SDVariable ret = step.mul(i_v.get(0)); - return Arrays.asList(ret); + if(scalarValue.getDouble(0) == 0.0){ + return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0))); + } else { + SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; + SDVariable ret = step.mul(i_v.get(0)); + return Collections.singletonList(ret); + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java new file mode 100644 index 000000000..3af7a4190 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java @@ -0,0 +1,43 @@ +package org.nd4j.linalg.api.ops.impl.scalar; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import 
org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.shade.guava.base.Preconditions; + +import java.util.Collections; +import java.util.List; + +public class RectifiedLinearDerivative extends DynamicCustomOp { + + public RectifiedLinearDerivative(){ } + + public RectifiedLinearDerivative(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public RectifiedLinearDerivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "relu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java deleted file mode 100644 index 8e2f6c790..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Broadcast.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
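The new RectifiedLinearDerivative op wraps the native relu_bp kernel; a minimal sketch of how it can be exercised directly (assuming the constructors shown above and the standard executioner API; the class name ReluBpSketch is illustrative only):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
    import org.nd4j.linalg.factory.Nd4j;

    public class ReluBpSketch {
        public static void main(String[] args) {
            INDArray input = Nd4j.createFromArray(-2f, -1f, 0f, 1f, 2f, 3f).reshape(2, 3);
            INDArray gradAtOutput = Nd4j.ones(2, 3);   // incoming gradient dL/dOut
            INDArray dLdIn = input.ulike();            // holds the gradient w.r.t. the input

            // relu_bp passes the gradient through where the input is positive, zero elsewhere
            Nd4j.getExecutioner().exec(new RectifiedLinearDerivative(input, gradAtOutput, dLdIn));
            System.out.println(dLdIn);                 // expected: zeros where input <= 0, ones elsewhere
        }
    }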
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.shape; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * Broadcast function - * - * @author Adam Gibson - */ -public class Broadcast extends DynamicCustomOp { - private long[] shape; - public Broadcast(SameDiff sameDiff,SDVariable iX, long[] shape) { - super(null,sameDiff,new SDVariable[]{iX}); - this.shape = shape; - } - - - public Broadcast() {} - - - @Override - public List calculateOutputShape() { - return Arrays.asList(LongShapeDescriptor.fromShape(shape, larg().dataType())); - } - - @Override - public String opName() { - return "broadcast"; - } - - - - @Override - public List doDiff(List i_v) { - throw new UnsupportedOperationException(); - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List calculateOutputDataTypes(List dataTypes){ - return Collections.singletonList(dataTypes.get(0)); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 231cf5783..6275ce210 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -66,11 +66,6 @@ public class ConfusionMatrix extends DynamicCustomOp { //Looks like this is implemented in practice using a large collection of discrete ops - not single TF import op? 
} - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public String opName() { return "confusion_matrix"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index f01e4a448..c4747a371 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -98,6 +98,9 @@ public class Eye extends DynamicCustomOp { } protected void addArgs() { + iArguments.clear(); + tArguments.clear(); + addIArgument(numRows); addIArgument(numCols); if(batchDimension != null) { @@ -105,6 +108,8 @@ public class Eye extends DynamicCustomOp { addIArgument(dim); } } + + addTArgument((double) dataType.toInt()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index 31718d337..5613cc85f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -74,7 +74,6 @@ public class Gather extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - } @Override @@ -82,30 +81,6 @@ public class Gather extends DynamicCustomOp { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); } - - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { -// super.resolvePropertiesFromSameDiffBeforeExecution(); - if (indices != null && numInputArguments() < 2) { - if (numInputArguments() == 0) { - INDArray a = Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT); - if (indices.length > 1) - a = a.reshape(indices.length); - else - a = a.reshape(new int[]{}); - - addInputArgument(args()[0].getArr(), a); - } else if (numInputArguments() == 1) { - addInputArgument(Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT)); - } - - } - - if (numIArguments() < 1) { - addIArgument(jaxis); - } - } - @Override public Map> mappingsForFunction() { Map> ret = new HashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index a11b241e5..1ad13ad40 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -82,12 +82,4 @@ public class Linspace extends DynamicCustomOp { public List doDiff(List gradients){ return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2))); } - - @Override - public List calculateOutputShape(){ - INDArray l = arg(2).getArr(); - if(l == null) - return 
Collections.emptyList(); - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{l.getLong(0)}, dataType)); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index aacfa19e1..568b14a44 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -85,14 +85,6 @@ public class Rank extends DynamicCustomOp { return "Rank"; } - @Override - public List calculateOutputShape() { - List ret = new ArrayList<>(); - ret.add(LongShapeDescriptor.fromShape(new long[]{}, DataType.INT)); - return ret; - } - - @Override public List doDiff(List i_v) { return Collections.singletonList(sameDiff.zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index 02f8f9445..af8940bf4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -110,15 +110,6 @@ public class Repeat extends DynamicCustomOp { super.initFromOnnx(node, initWith, attributesForNode, graph); } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if (numOutputArguments() < getDescriptor().getNumOutputs()) { - for (val output : outputVariables()) { - addOutputArgument(output.getArr()); - } - } - } - @Override public String onnxName() { return "Repeat"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java index bfd8f58ec..816654d13 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java @@ -41,6 +41,7 @@ public class Squeeze extends DynamicCustomOp { public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) { super(null, sameDiff, new SDVariable[]{arg}); this.squeezeDims = squeezeDims; + addIArgument(squeezeDims); } @Override @@ -53,14 +54,6 @@ public class Squeeze extends DynamicCustomOp { addIArgument(squeezeDims); } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - if (squeezeDims != null && numIArguments() < squeezeDims.length) { - addIArgument(squeezeDims); - } - } - @Override public String opName() { return "squeeze"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 2de0a29c5..782c70859 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; -import 
com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -65,11 +65,6 @@ public class Transpose extends DynamicCustomOp { public Transpose() { } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public Map> mappingsForFunction() { Map> ret = new LinkedHashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java index 9dd6b6338..f9eb1f95a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java @@ -152,17 +152,4 @@ public class Unstack extends DynamicCustomOp { return out; } - @Override - public List calculateOutputShape(){ - //TEMPORARY workaround for: https://github.com/deeplearning4j/deeplearning4j/issues/7093 - if(inputArguments.size() == 1 && inputArguments.get(0).rank() == 1){ - INDArray arr = inputArguments.get(0); - Preconditions.checkState(jaxis == 0, "Can only unstack along dimension 0 for rank 1 arrays, got axis %s for array %ndShape", jaxis, arr); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(new long[0], arr.dataType()); - List out = Arrays.asList(ArrayUtil.nTimes((int)arr.length(), lsd, LongShapeDescriptor.class)); - return out; - } - return super.calculateOutputShape(); - } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java index bcc4e1844..b027750fc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/BaseTensorOp.java @@ -32,7 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; -public abstract class BaseTensorOp extends DynamicCustomOp { +public abstract class BaseTensorOp extends DynamicCustomOp { public BaseTensorOp(String name, SameDiff sameDiff, SDVariable[] args){ super(name, sameDiff, args); @@ -78,8 +78,7 @@ public abstract class BaseTensorOp extends DynamicCustomOp { @Override public List calculateOutputShape() { - //Not used/not required - return Collections.emptyList(); + throw new UnsupportedOperationException("calculateOutputShape() is not supported for tensor ops."); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java index 276dadcab..0503d377b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java @@ -41,12 +41,6 @@ public class TensorArraySize extends BaseTensorOp { return "tensorarraysizev3"; } - @Override - public List calculateOutputShape() { - 
// output is scalar only - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[]{}, DataType.LONG)); - } - @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java index aa85461e6..ae4a21df7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BaseDynamicTransformOp.java @@ -46,68 +46,6 @@ public abstract class BaseDynamicTransformOp extends DynamicCustomOp { super(null, inputs, outputs); } - - @Override - public List calculateOutputShape() { - long[] firstArgShape; - long[] secondArgShape; - DataType dtypeZ; - - if(numInputArguments() == 2){ - return super.calculateOutputShape(); //Use c++ shape calc, which also accounts for empty broadcast cases, etc -// firstArgShape = inputArguments.get(0).shape(); -// secondArgShape = inputArguments.get(1).shape(); -// dtypeZ = Shape.pickPairwiseDataType(inputArguments.get(0).dataType(), inputArguments.get(1).dataType()); - } else { - val args = args(); - if (args.length < 2) { - if (args[0] == null || (inputArguments.isEmpty() && args[0].getShape() == null)) { - return Collections.emptyList(); - } - DataType dtypeX = !inputArguments.isEmpty() ? inputArguments.get(0).dataType() : args[0].dataType(); - long[] shape = !inputArguments.isEmpty() ? inputArguments.get(0).shape() : args[0].getShape(); - - return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dtypeX)); - } - - if(inputArguments.size() == 2 && inputArguments.get(0) != null && inputArguments.get(1) != null){ - firstArgShape = inputArguments.get(0).shape(); - secondArgShape = inputArguments.get(1).shape(); - } else { - firstArgShape = args[0].getShape(); - secondArgShape = args[1].getShape(); - } - if (args[0] == null || args[0].getShape() == null) { - return Collections.emptyList(); - } - - if (args[1] == null || args[1].getShape() == null) { - return Collections.emptyList(); - } - - // detecting datatype based on both args - val dtypeX = inputArguments.size() > 0 ? inputArguments.get(0).dataType() : args[0].dataType(); - val dtypeY = inputArguments.size() > 1 ? 
inputArguments.get(1).dataType() : args[1].dataType(); - dtypeZ = Shape.pickPairwiseDataType(dtypeX, dtypeY); - } - - - - if(Arrays.equals(firstArgShape, secondArgShape)){ - try { - return Collections.singletonList(LongShapeDescriptor.fromShape(firstArgShape, dtypeZ)); - } catch (Throwable e) { - throw new RuntimeException("calculateOutputShape() failed for [" + this.opName() + "]", e); - } - } else { - //Handle broadcast shape: [1,4]+[3,1] = [3,4] - Shape.assertBroadcastable(firstArgShape, secondArgShape, this.getClass()); - val outShape = Shape.broadcastOutputShape(firstArgShape, secondArgShape); - - return Collections.singletonList(LongShapeDescriptor.fromShape(outShape, dtypeZ)); - } - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got input %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java index 8894a87ec..091846db8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/HistogramFixedWidth.java @@ -70,16 +70,6 @@ public class HistogramFixedWidth extends DynamicCustomOp { //No op - just need the inputs } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments.isEmpty()){ - //Num bins is 3rd array - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - - @Override public List doDiff(List f1) { throw new UnsupportedOperationException("Not supported"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index f940ac7a0..72d2823b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -82,17 +82,6 @@ public class Pad extends DynamicCustomOp { //Constant value is resolved just before execution } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3){ - INDArray arr = arg(2).getArr(); - this.tArguments.clear(); - this.tArguments.add(arr.getDouble(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - - @Override public List doDiff(List i_v) { //Pad backprop: it's basically slice op... 
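These hunks drop per-op Java implementations of calculateOutputShape() in favour of the native shape-inference path; a minimal sketch of querying output shapes through the backend instead (assuming the DynamicCustomOp builder and the executioner method already used in the removed code; not a required pattern):

    import java.util.List;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.api.shape.LongShapeDescriptor;
    import org.nd4j.linalg.factory.Nd4j;

    public class BackendShapeSketch {
        public static void main(String[] args) {
            // Build a simple custom op with two inputs; no Java-side shape override is involved
            DynamicCustomOp op = DynamicCustomOp.builder("add")
                    .addInputs(Nd4j.ones(3, 4), Nd4j.ones(3, 4))
                    .build();
            // Shape inference is delegated to the native backend
            List<LongShapeDescriptor> shapes = Nd4j.getExecutioner().calculateOutputShape(op);
            System.out.println(shapes);   // expected: a single descriptor for shape [3, 4]
        }
    }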
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java index 28151a899..0a2ab4f20 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; -import com.google.common.primitives.Doubles; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Doubles; +import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java index 318a7dc02..3a9173654 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class CyclicRShiftBits extends BaseDynamicTransformOp { - public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) { + super(sameDiff, new SDVariable[] {x, shift} ,false); } - public CyclicRShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public CyclicRShiftBits(INDArray input, INDArray shift, INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public CyclicRShiftBits(INDArray input, int shift) { - this(input, shift,null); + public CyclicRShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public CyclicRShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java index b4291c5df..20b6f6955 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class CyclicShiftBits extends BaseDynamicTransformOp { - public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public CyclicShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) { + super(sameDiff, new SDVariable[] {x, shift} ,false); } - public CyclicShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public CyclicShiftBits(INDArray input, INDArray shift, 
INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public CyclicShiftBits(INDArray input, int shift) { - this(input, shift,null); + public CyclicShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public CyclicShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index b786602d3..95019c5b2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -71,18 +71,6 @@ public class EqualTo extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 27b2ea189..72a46f111 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -73,18 +73,6 @@ public class GreaterThan extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index 48d3953aa..50d1c7c43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -76,18 +76,6 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return 
Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index dde92a947..7c7c34fc5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -41,30 +42,28 @@ import java.util.List; public class LayerNorm extends DynamicCustomOp { private boolean noBias = false; + private boolean channelsFirst; - public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, int... dimensions) { - super(null, sameDiff, new SDVariable[] {input, gain, bias}, false); - Preconditions.checkArgument(bias != null, "LayerNorm: Use constructor without bias argument if bias is null / not available."); + public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { + super(null, sameDiff, wrapFilterNull(input, gain, bias), false); + this.noBias = bias == null; + this.channelsFirst = channelsFirst; setDimensions(dimensions); } - public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, int... dimensions) { - super(null, sameDiff, new SDVariable[] {input, gain}, false); - noBias = true; + public LayerNorm(SameDiff sameDiff, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { + this(sameDiff, input, gain, null, channelsFirst, dimensions); + } + + public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, boolean channelsFirst, int... dimensions) { + super("layer_norm", wrapFilterNull(input, gain, bias), wrapOrNull(result)); + this.noBias = bias == null; + this.channelsFirst = channelsFirst; setDimensions(dimensions); } - public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, int... dimensions) { - super("layer_norm", new INDArray[]{input, gain, bias}, new INDArray[]{result}); - Preconditions.checkArgument(bias != null, "LayerNorm: Use different constructor if bias is null."); - - setDimensions(dimensions); - } - - public LayerNorm(INDArray input, INDArray gain, INDArray result, int... dimensions) { - super("layer_norm", new INDArray[]{input, gain}, new INDArray[]{result}); - noBias = true; - setDimensions(dimensions); + public LayerNorm(INDArray input, INDArray gain, INDArray result, boolean channelsFirst, int... 
dimensions) { + this(input, gain, null, result, channelsFirst, dimensions); } @Override @@ -73,7 +72,10 @@ public class LayerNorm extends DynamicCustomOp { Preconditions.checkArgument(dimensions.length > 0, "LayerNorm: You have to provide dimensions"); this.dimensions = dimensions; + this.iArguments.clear(); addIArgument(dimensions); + this.bArguments.clear(); + this.bArguments.add(channelsFirst); } @Override @@ -96,9 +98,9 @@ public class LayerNorm extends DynamicCustomOp { public List doDiff(List gradient) { SDVariable[] ret; if(noBias){ - ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), dimensions); + ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions); }else{ - ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), dimensions); + ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions); } return Arrays.asList(ret); } @@ -115,4 +117,8 @@ public class LayerNorm extends DynamicCustomOp { return Collections.singletonList(first); } + @Override + public int numOutputArguments() { + return noBias ? 2 : 3; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java index cfd4fff65..f55db2e50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -39,33 +40,30 @@ import java.util.List; public class LayerNormBp extends DynamicCustomOp { private boolean noBias = false; + private boolean channelsFirst; - public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int... dimensions) { - super(null, sameDiff, new SDVariable[] {input, gain, bias, gradient}, false); - Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available."); - + public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) { + super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false); + this.noBias = bias == null; + this.channelsFirst = channelsFirst; setDimensions(dimensions); } - public LayerNormBp(INDArray input, INDArray gain, INDArray bias, INDArray grad, INDArray dLdx, INDArray dLdg, INDArray dLdb, int... dimensions) { - super("layer_norm_bp", new INDArray[]{input, gain, bias, grad}, new INDArray[]{dLdx, dLdg, dLdb}); - Preconditions.checkArgument(bias != null, "LayerNormBp: Use constructor without bias argument if bias is null / not available."); - + public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... 
dimensions) { + super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb)); + this.noBias = bias == null; + this.channelsFirst = channelsFirst; setDimensions(dimensions); } - public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, int... dimensions) { - super(null, sameDiff, new SDVariable[] {input, gain, gradient}, false); - noBias = true; - setDimensions(dimensions); + public LayerNormBp(SameDiff sameDiff, SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) { + this(sameDiff, input, gain, null, gradient, channelsFirst, dimensions); } - public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, int... dimensions) { - super("layer_norm_bp", new INDArray[]{input, gain, grad}, new INDArray[]{dLdx, dLdg}); - noBias = true; - setDimensions(dimensions); + public LayerNormBp(INDArray input, INDArray gain, INDArray grad, INDArray dLdx, INDArray dLdg, boolean channelsFirst, int... dimensions) { + this(input, gain, null, grad, dLdx, dLdg, null, channelsFirst, dimensions); } @Override @@ -74,7 +72,10 @@ public class LayerNormBp extends DynamicCustomOp { Preconditions.checkArgument(dimensions.length > 0, "LayerNormBp: You have to provide dimensions"); this.dimensions = dimensions; + this.iArguments.clear(); addIArgument(dimensions); + this.bArguments.clear(); + addBArgument(channelsFirst); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index 0ee59458c..4c345070e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -73,18 +73,6 @@ public class LessThan extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 56a5882db..89c08fe65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -71,18 +71,6 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return 
Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index b1f0dadbb..62d3bedfa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -72,18 +72,6 @@ public class NotEqualTo extends BaseDynamicTransformOp { return Arrays.asList(sameDiff.zerosLike(args()[0]), sameDiff.zerosLike(args()[1])); } - @Override - public List calculateOutputShape() { - if (args() == null) - return Collections.emptyList(); - - if (inputArguments.size() == 0) - return Collections.emptyList(); - - - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 80697efa3..4435615f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class RShiftBits extends BaseDynamicTransformOp { - public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public RShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); } - public RShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public RShiftBits(INDArray input, INDArray shift, INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public RShiftBits(INDArray input, int shift) { - this(input, shift,null); + public RShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public RShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index 8c652f72d..5501324f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class ShiftBits extends BaseDynamicTransformOp { - public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new 
SDVariable[] {x} ,false); - this.addIArgument(shift); + public ShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); } - public ShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public ShiftBits(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); } - public ShiftBits(INDArray input, int shift) { - this(input, shift,null); + public ShiftBits(INDArray x, INDArray y) { + this(x, y,x.ulike()); } public ShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index 01408ea4d..98a479542 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -125,41 +125,6 @@ public class Cast extends BaseDynamicTransformOp { return ret; } - @Override - public List calculateOutputShape() { - if(inputArguments.size() > 0){ - long[] s = inputArguments.get(0).shape(); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst); - if(inputArguments.get(0).isEmpty()){ - long e = lsd.getExtras(); - e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY); - lsd.setExtras(e); - } - return Collections.singletonList(lsd); - } - - if (arg() != null && (arg().getArr() != null || arg().getShape() != null)) { - if (arg().getArr() != null) { - long[] s = arg().getArr().shape(); - LongShapeDescriptor lsd = LongShapeDescriptor.fromShape(s, typeDst); - if(inputArguments.size() > 0 && inputArguments.get(0) != null && inputArguments.get(0).isEmpty()){ - long e = lsd.getExtras(); - e = ArrayOptionsHelper.setOptionBit(e, ArrayType.EMPTY); - lsd.setExtras(e); - } - return Collections.singletonList(lsd); - } else { - long[] s = arg().getShape(); - if(Shape.isPlaceholderShape(s)){ - return Collections.emptyList(); - } - return Collections.singletonList(LongShapeDescriptor.fromShape(s, typeDst)); - } - } - - return Collections.emptyList(); - } - @Override public void setValueFor(Field target, Object value) { //This is a hack around a property mapping issue - TF datatype DT_DOUBLE return attribute.getType() of DT_DOUBLE which doesn't make sense diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java index d61f7a067..10abb69a2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/GradientBackwardsMarker.java @@ -77,7 +77,7 @@ public class GradientBackwardsMarker extends DynamicCustomOp { @Override public List calculateOutputShape() { - return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.FLOAT)); + throw new UnsupportedOperationException("calculateOutputShape() is not supported for control flow ops."); } @Override diff --git 
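
RShiftBits and ShiftBits above now take the shift amount as a second array operand (y) instead of a scalar integer argument, and input.ulike() / x.ulike() allocates a fresh result of matching shape and type when no output array is supplied. The element-wise semantics this implies, sketched in plain Java (illustrative only):

    public class ShiftSemanticsSketch {
        public static void main(String[] args) {
            long[] x     = {8, 16, 32, 64};
            long[] shift = {1,  2,  3,  4};      // per-element shift amounts, like the new y operand
            long[] left  = new long[x.length];   // analogous to x.ulike(): same shape, freshly allocated
            long[] right = new long[x.length];
            for (int i = 0; i < x.length; i++) {
                left[i]  = x[i] << shift[i];     // ShiftBits
                right[i] = x[i] >> shift[i];     // RShiftBits
            }
            System.out.println(java.util.Arrays.toString(left));   // [16, 64, 256, 1024]
            System.out.println(java.util.Arrays.toString(right));  // [4, 4, 4, 4]
        }
    }
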
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java index dcbcc271f..44648cc4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/BaseArithmeticBackpropOp.java @@ -46,11 +46,6 @@ public abstract class BaseArithmeticBackpropOp extends BaseDynamicTransformOp { throw new UnsupportedOperationException("Not supported"); } - @Override - public List calculateOutputShape(){ - return Nd4j.getExecutioner().calculateOutputShape(this); - } - @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got input %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index 3e548226f..4b5825d8f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -79,11 +79,4 @@ public class Identity extends BaseDynamicTransformOp { return dataTypes; } - @Override - public List calculateOutputShape() { - if(inputArguments == null || inputArguments.isEmpty()) - return Collections.emptyList(); - return Collections.singletonList(inputArguments.get(0).shapeDescriptor()); - } - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 73d4d281e..6d6798701 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMax extends DynamicCustomOp { return "UnsortedSegmentMax"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index a110d7225..f51b94218 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp { return "UnsortedSegmentMean"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 6666981ff..1b885676e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -54,14 +54,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp { return "UnsortedSegmentMin"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 36ba70342..b2e254fb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -54,14 +54,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp { return "UnsortedSegmentProd"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index cfb62450a..ef34e9f81 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -53,14 +53,6 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { return "UnsortedSegmentSqrtN"; } - @Override - public void 
resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 3e4c4d2f1..466cc8cf2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -55,14 +55,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp { return "UnsortedSegmentSum"; } - @Override - public void resolvePropertiesFromSameDiffBeforeExecution() { - if(args().length == 3 && iArguments == null || iArguments.size() == 0){ - addIArgument(arg(2).getArr().getInt(0)); - } - super.resolvePropertiesFromSameDiffBeforeExecution(); - } - @Override public List doDiff(List gradients){ return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index bbe133dbb..ef4331b0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -17,8 +17,8 @@ package org.nd4j.linalg.api.shape; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; import lombok.val; import org.nd4j.base.Preconditions; @@ -31,8 +31,6 @@ import org.nd4j.linalg.api.shape.options.ArrayType; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; -import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import java.nio.*; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java index 6039e4b1a..f3ac98565 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java @@ -181,37 +181,37 @@ public class NDArrayCreationUtil { INDArray[] out = new INDArray[12]; INDArray temp01 = Nd4j.linspace(1, cols * rows * 4, cols * rows * 4, dataType).reshape(cols, rows, 4); - out[0] = temp01.javaTensorAlongDimension(0, 0, 1).reshape(rows, cols); + out[0] = temp01.tensorAlongDimension(0, 0, 1).reshape(rows, cols); long[] temp01Shape = new long[] {cols, rows, 4}; int len = ArrayUtil.prod(temp01Shape); temp01 = Nd4j.linspace(1, len, len, 
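
Shape.java above is one of many files switched from com.google.common.* to the shaded copy under org.nd4j.shade.guava.*, which keeps ND4J's Guava version off the application classpath. Call sites only change their imports; a minimal example, assuming the shaded ND4J artifacts are on the classpath:

    import org.nd4j.shade.guava.primitives.Ints;
    import org.nd4j.shade.guava.primitives.Longs;

    public class ShadedGuavaSketch {
        public static void main(String[] args) {
            // Same Guava API as before, just resolved from the shaded package
            int[] dims   = Ints.toArray(java.util.Arrays.asList(0, 2, 3));
            long[] shape = Longs.toArray(java.util.Arrays.asList(2L, 3L, 4L));
            System.out.println(java.util.Arrays.toString(dims) + " " + java.util.Arrays.toString(shape));
        }
    }
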
dataType).reshape(temp01Shape); - out[1] = temp01.javaTensorAlongDimension(2, 0, 1).reshape(rows, cols); + out[1] = temp01.tensorAlongDimension(2, 0, 1).reshape(rows, cols); Nd4j.getRandom().setSeed(seed); INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(new long[] {cols, 4, rows}); - out[2] = temp02.javaTensorAlongDimension(0, 0, 2).reshape(rows, cols); + out[2] = temp02.tensorAlongDimension(0, 0, 2).reshape(rows, cols); temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows); - out[3] = temp02.javaTensorAlongDimension(2, 0, 2).reshape(rows, cols); + out[3] = temp02.tensorAlongDimension(2, 0, 2).reshape(rows, cols); INDArray temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4); - out[4] = temp10.javaTensorAlongDimension(0, 1, 0).reshape(rows, cols); + out[4] = temp10.tensorAlongDimension(0, 1, 0).reshape(rows, cols); temp10 = Nd4j.linspace(1, len, len, dataType).reshape(rows, cols, 4); - out[5] = temp10.javaTensorAlongDimension(2, 1, 0).reshape(rows, cols); + out[5] = temp10.tensorAlongDimension(2, 1, 0).reshape(rows, cols); INDArray temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows); - out[6] = temp12.javaTensorAlongDimension(0, 1, 2).reshape(rows, cols); + out[6] = temp12.tensorAlongDimension(0, 1, 2).reshape(rows, cols); temp12 = Nd4j.linspace(1, len, len, dataType).reshape(4, cols, rows); - out[7] = temp12.javaTensorAlongDimension(2, 1, 2).reshape(rows, cols); + out[7] = temp12.tensorAlongDimension(2, 1, 2).reshape(rows, cols); INDArray temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols); - out[8] = temp20.javaTensorAlongDimension(0, 2, 0).reshape(rows, cols); + out[8] = temp20.tensorAlongDimension(0, 2, 0).reshape(rows, cols); temp20 = Nd4j.linspace(1, len, len, dataType).reshape(rows, 4, cols); - out[9] = temp20.javaTensorAlongDimension(2, 2, 0).reshape(rows, cols); + out[9] = temp20.tensorAlongDimension(2, 2, 0).reshape(rows, cols); INDArray temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols); - out[10] = temp21.javaTensorAlongDimension(0, 2, 1).reshape(rows, cols); + out[10] = temp21.tensorAlongDimension(0, 2, 1).reshape(rows, cols); temp21 = Nd4j.linspace(1, len, len, dataType).reshape(4, rows, cols); - out[11] = temp21.javaTensorAlongDimension(2, 2, 1).reshape(rows, cols); + out[11] = temp21.tensorAlongDimension(2, 2, 1).reshape(rows, cols); String baseMsg = "getTensorAlongDimensionMatricesWithShape(" + rows + "," + cols + "," + seed + ")"; List> list = new ArrayList<>(12); @@ -361,9 +361,9 @@ public class NDArrayCreationUtil { val shape4d1 = new long[]{shape[0], shape[1], shape[2], 3}; int lenshape4d1 = ArrayUtil.prod(shape4d1); INDArray orig1a = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1); - INDArray tad1a = orig1a.javaTensorAlongDimension(0, 0, 1, 2); + INDArray tad1a = orig1a.tensorAlongDimension(0, 0, 1, 2); INDArray orig1b = Nd4j.linspace(1, lenshape4d1, lenshape4d1, dataType).reshape(shape4d1); - INDArray tad1b = orig1b.javaTensorAlongDimension(1, 0, 1, 2); + INDArray tad1b = orig1b.tensorAlongDimension(1, 0, 1, 2); list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); @@ -371,19 +371,19 @@ public class NDArrayCreationUtil { long[] shape4d2 = {3, shape[0], shape[1], shape[2]}; int lenshape4d2 = ArrayUtil.prod(shape4d2); INDArray orig2 = Nd4j.linspace(1, lenshape4d2, lenshape4d2, dataType).reshape(shape4d2); - INDArray tad2 = orig2.javaTensorAlongDimension(1, 1, 2, 3); + INDArray tad2 = 
orig2.tensorAlongDimension(1, 1, 2, 3); list.add(new Pair<>(tad2, baseMsg + ".get(2)")); long[] shape4d3 = {shape[0], shape[1], 3, shape[2]}; int lenshape4d3 = ArrayUtil.prod(shape4d3); INDArray orig3 = Nd4j.linspace(1, lenshape4d3, lenshape4d3, dataType).reshape(shape4d3); - INDArray tad3 = orig3.javaTensorAlongDimension(1, 1, 3, 0); + INDArray tad3 = orig3.tensorAlongDimension(1, 1, 3, 0); list.add(new Pair<>(tad3, baseMsg + ".get(3)")); long[] shape4d4 = {shape[0], 3, shape[1], shape[2]}; int lenshape4d4 = ArrayUtil.prod(shape4d4); INDArray orig4 = Nd4j.linspace(1, lenshape4d4, lenshape4d4, dataType).reshape(shape4d4); - INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3); + INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3); list.add(new Pair<>(tad4, baseMsg + ".get(4)")); return list; @@ -513,9 +513,9 @@ public class NDArrayCreationUtil { int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3]}; int len = ArrayUtil.prod(shape4d1); INDArray orig1a = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1)); - INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4); + INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4); INDArray orig1b = Nd4j.linspace(1, len, len, dataType).reshape(ArrayUtil.toLongArray(shape4d1)); - INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4); + INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4); list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); @@ -523,19 +523,19 @@ public class NDArrayCreationUtil { int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3]}; int len2 = ArrayUtil.prod(shape4d2); INDArray orig2 = Nd4j.linspace(1, len2, len2, dataType).reshape(ArrayUtil.toLongArray(shape4d2)); - INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 4, 2, 1); + INDArray tad2 = orig2.tensorAlongDimension(1, 3, 4, 2, 1); list.add(new Pair<>(tad2, baseMsg + ".get(2)")); int[] shape4d3 = {shape[0], shape[1], 3, shape[2], shape[3]}; int len3 = ArrayUtil.prod(shape4d3); INDArray orig3 = Nd4j.linspace(1, len3, len3, dataType).reshape(ArrayUtil.toLongArray(shape4d3)); - INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 0); + INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 0); list.add(new Pair<>(tad3, baseMsg + ".get(3)")); int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3}; int len4 = ArrayUtil.prod(shape4d4); INDArray orig4 = Nd4j.linspace(1, len4, len4, dataType).reshape(ArrayUtil.toLongArray(shape4d4)); - INDArray tad4 = orig4.javaTensorAlongDimension(1, 2, 0, 3, 1); + INDArray tad4 = orig4.tensorAlongDimension(1, 2, 0, 3, 1); list.add(new Pair<>(tad4, baseMsg + ".get(4)")); return list; @@ -655,26 +655,26 @@ public class NDArrayCreationUtil { Nd4j.getRandom().setSeed(seed); int[] shape4d1 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]}; INDArray orig1a = Nd4j.rand(dataType, shape4d1); - INDArray tad1a = orig1a.javaTensorAlongDimension(0, 1, 2, 3, 4, 5); + INDArray tad1a = orig1a.tensorAlongDimension(0, 1, 2, 3, 4, 5); INDArray orig1b = Nd4j.rand(dataType, shape4d1); - INDArray tad1b = orig1b.javaTensorAlongDimension(2, 1, 2, 3, 4, 5); + INDArray tad1b = orig1b.tensorAlongDimension(2, 1, 2, 3, 4, 5); list.add(new Pair<>(tad1a, baseMsg + ".get(0)")); list.add(new Pair<>(tad1b, baseMsg + ".get(1)")); int[] shape4d2 = {3, shape[0], shape[1], shape[2], shape[3], shape[4]}; INDArray orig2 = Nd4j.rand(dataType, shape4d2); - INDArray tad2 = orig2.javaTensorAlongDimension(1, 3, 5, 4, 2, 1); + INDArray tad2 = 
orig2.tensorAlongDimension(1, 3, 5, 4, 2, 1); list.add(new Pair<>(tad2, baseMsg + ".get(2)")); int[] shape4d3 = {shape[0], shape[1], shape[2], shape[3], shape[4], 2}; INDArray orig3 = Nd4j.rand(dataType, shape4d3); - INDArray tad3 = orig3.javaTensorAlongDimension(1, 4, 1, 3, 2, 0); + INDArray tad3 = orig3.tensorAlongDimension(1, 4, 1, 3, 2, 0); list.add(new Pair<>(tad3, baseMsg + ".get(3)")); int[] shape4d4 = {shape[0], shape[1], shape[2], shape[3], 3, shape[4]}; INDArray orig4 = Nd4j.rand(dataType, shape4d4); - INDArray tad4 = orig4.javaTensorAlongDimension(1, 5, 2, 0, 3, 1); + INDArray tad4 = orig4.tensorAlongDimension(1, 5, 2, 0, 3, 1); list.add(new Pair<>(tad4, baseMsg + ".get(4)")); return list; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java index 0636cc75b..8eff5cf5d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.dataset; -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; +import org.nd4j.shade.guava.collect.Lists; +import org.nd4j.shade.guava.collect.Maps; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java index 2ae7a32c4..cb354c2d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java @@ -16,8 +16,7 @@ package org.nd4j.linalg.dataset; -import com.google.common.base.Function; -import com.google.common.collect.Lists; +import org.nd4j.shade.guava.collect.Lists; import lombok.extern.slf4j.Slf4j; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; @@ -26,8 +25,6 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Condition; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.FeatureUtil; @@ -374,7 +371,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { long nTensors = labels.tensorsAlongDimension(1); for (int i = 0; i < nTensors; i++) { INDArray row = labels.tensorAlongDimension(i, 1); - INDArray javaRow = labels.javaTensorAlongDimension(i, 1); + INDArray javaRow = labels.tensorAlongDimension(i, 1); int maxIdx = Nd4j.getBlasWrapper().iamax(row); int maxIdxJava = Nd4j.getBlasWrapper().iamax(javaRow); if (maxIdx < 0) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java index a5a46dbcd..5f6731c26 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java +++ 
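
The test utility above now calls tensorAlongDimension(...) everywhere the Java-side javaTensorAlongDimension(...) helper was previously used, so the same TAD views come from the single native-backed code path. A small usage sketch (assumes an ND4J backend on the classpath):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class TadSketch {
        public static void main(String[] args) {
            INDArray arr = Nd4j.linspace(1, 24, 24, DataType.FLOAT).reshape(2, 3, 4);
            INDArray tad = arr.tensorAlongDimension(0, 1, 2);   // first sub-array along dims (1,2)
            System.out.println(java.util.Arrays.toString(tad.shape()));   // [3, 4]
        }
    }
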
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java @@ -16,12 +16,10 @@ package org.nd4j.linalg.dataset.api; -import com.google.common.base.Function; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.indexing.conditions.Condition; import java.io.File; import java.io.InputStream; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 487196912..1edf0d651 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -1281,7 +1281,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { public INDArray scalar(Number value) { MemoryWorkspace ws = Nd4j.getMemoryManager().getCurrentWorkspace(); - if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends com.google.common.util.concurrent.AtomicDouble */ + if (value instanceof Double || value instanceof AtomicDouble) /* note that org.nd4j.linalg.primitives.AtomicDouble extends org.nd4j.shade.guava.util.concurrent.AtomicDouble */ return scalar(value.doubleValue()); else if (value instanceof Float) return scalar(value.floatValue()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index eb3250e3c..0bbf69ebe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.factory; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.NonNull; import lombok.val; import lombok.var; @@ -103,7 +103,6 @@ import java.text.NumberFormat; import java.text.ParseException; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; @@ -153,7 +152,6 @@ public class Nd4j { public static RandomFactory randomFactory; private static MemoryWorkspaceManager workspaceManager; private static DeallocatorService deallocatorService; - private static final AtomicInteger numThreads = new AtomicInteger(-1); private static AtomicReference defaultFloatingPointDataType; private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE; @@ -2267,15 +2265,12 @@ public class Nd4j { Preconditions.checkState(data.length == numColumns, "Data has inconsistent number of columns: data length %s, numColumns %s", data.length, numColumns); data2.add(readSplit(data)); - - } - ret = Nd4j.create(dataType, data2.size(), numColumns); - for (int i = 0; i < data2.size(); i++) { - float[] row = data2.get(i); - INDArray arr = Nd4j.create(row, new long[]{1, row.length}, dataType); - ret.putRow(i, arr); + float[][] fArr = new float[data2.size()][0]; + for(int i=0; 
i(new Pointer[]{null, stream}); nativeOps.convertTypes(p, typeSrc.ordinal(), source, length, typeDst.ordinal(), target); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } @Override @@ -1277,7 +1111,13 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { srcPtr = nativeOps.mallocDevice(ssize, 0, 0); dstPtr = nativeOps.mallocDevice(size, 0, 0); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + nativeOps.memcpyAsync(srcPtr, source, ssize, CudaConstants.cudaMemcpyHostToDevice, stream); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } else { // decompressing throw new UnsupportedOperationException(); @@ -1288,9 +1128,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { stream.synchronize(); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + if (buffer instanceof CompressedDataBuffer) { nativeOps.freeDevice(srcPtr, 0); nativeOps.freeDevice(dstPtr, 0); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } } @@ -1309,13 +1155,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { val size = ((CompressedDataBuffer) source).getCompressionDescriptor().getCompressedLength(); srcPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false); nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } // if true - we're compressing into host memory if (target instanceof CompressedDataBuffer) { val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength(); dstPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false); - //nativeOps.memcpyAsync(dstPtr, target.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); } } else { // if true - we're decompressing from host memory @@ -1325,6 +1173,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { srcPtr = nativeOps.mallocDevice(size, 0, 0); nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); stream.synchronize(); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } else srcPtr = AtomicAllocator.getInstance().getPointer(source); @@ -1333,8 +1184,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { log.info("Replacing target ptr"); val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength(); dstPtr = nativeOps.mallocDevice(size, 0, 0); - //nativeOps.memcpyAsync(dstPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); - //stream.synchronize(); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } else dstPtr = AtomicAllocator.getInstance().getPointer(target); } @@ -1342,6 +1194,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { convertDataEx(typeSrc, srcPtr, typeDst, dstPtr, target.length()); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + Nd4j.getExecutioner().commit(); @@ -1364,6 +1219,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { } + if (nativeOps.lastErrorCode() != 0) + throw new 
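
The JCublas factory above now polls the native layer after each call and rethrows its message as a Java exception. Since the same three-line check recurs throughout this patch, the pattern can be captured in a tiny guard (hypothetical helper, not part of the patch; assumes the org.nd4j.nativeblas.NativeOps interface these classes already hold):

    final class NativeErrorGuard {
        private NativeErrorGuard() {}

        /** Rethrow the last native error, if any, as a Java exception. */
        static void throwIfError(org.nd4j.nativeblas.NativeOps nativeOps) {
            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }
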
RuntimeException(nativeOps.lastErrorMessage()); + Nd4j.getExecutioner().commit(); } @@ -1462,6 +1320,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)) ); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().getFlowController().registerActionAllWrite(context, result); AtomicAllocator.getInstance().getFlowController().registerAction(context,null, result); @@ -1517,6 +1378,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { descending ); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().getFlowController().registerAction(context, x); @@ -1565,6 +1428,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { descending ); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().getFlowController().registerAction(context, x); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 789f0f1a3..c5b02a82f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -207,6 +207,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + + AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); profilingConfigurableHookOut(op, st); @@ -222,6 +226,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { */ protected INDArray naiveExec(ReduceOp op, int... dimension) { long st = profilingConfigurableHookIn(op); + + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." 
+ + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return op.z(); + } else { + op.setZ(op.x().dup()); + return op.z(); + } + } + INDArray ret = op.z(); checkForCompression(op); @@ -461,6 +480,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + profilingConfigurableHookOut(op, st); return op.z(); @@ -475,6 +497,20 @@ public class CudaExecutioner extends DefaultOpExecutioner { public INDArray exec(ReduceOp op) { checkForCompression(op); + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return op.z(); + } else { + op.setZ(op.x().dup()); + return op.z(); + } + } + val dimension = op.dimensions().toIntVector(); if (extraz.get() == null) @@ -619,7 +655,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); - + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); @@ -777,6 +814,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException("Unknown opType: " + op.getOpType()); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); profilingConfigurableHookOut(op, st); @@ -854,13 +894,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { //long dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); Pointer dimensionPointer = AtomicAllocator.getInstance() - .getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context); + .getHostPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension)); nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, extraArgs, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, + dimensionPointer, (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); @@ -868,6 +908,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + profilingConfigurableHookOut(op, st); return null; @@ -876,6 +919,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext invoke(ReduceOp op, int[] dimension) { + CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + 
if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return context; + } else { + op.setZ(op.x().dup()); + return context; + } + } + long st = profilingConfigurableHookIn(op); checkForCompression(op); @@ -899,17 +958,15 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); - if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - val tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + val tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null) : tadManager.getTADOnlyShapeInfo(op.x(), dimension); val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); - val offsets = tadBuffers.getSecond(); + val offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context); @@ -1105,6 +1162,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); profilingConfigurableHookOut(op, st); @@ -1194,6 +1253,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException(); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); profilingConfigurableHookOut(op, st); @@ -1268,6 +1330,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar()); profilingConfigurableHookOut(op, st); @@ -1423,6 +1488,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); @@ -1582,6 +1649,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType)); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + surfacePoint.tickHostWrite(); } @@ -1676,6 +1746,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { numIndexArguments, iPtr, numIntArrays, AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context), numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType)); + + if (nativeOps.lastErrorCode() != 0) + throw new 
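
The three reduce entry points patched above special-case an empty axis list coming from TensorFlow import: reducing over no axes is the identity, which is deliberately different from ND4J's convention that a length-0 axis array means "all dimensions". Roughly, in terms of the public API (assumes an ND4J backend; illustrative only):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EmptyReduceSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(new float[][]{{1, 2}, {3, 4}});
            INDArray total = x.sum(new int[0]);   // empty int[] = "all dimensions": scalar 10
            INDArray keep  = x.dup();             // TF-style empty axis list: identity, what the patch returns via assign()/dup()
            System.out.println(total);
            System.out.println(keep.equalShapes(x));   // true
        }
    }
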
RuntimeException(nativeOps.lastErrorMessage()); } /** @@ -1739,6 +1812,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); profilingConfigurableHookOut(op, st); @@ -1969,6 +2045,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { nativeOps.decodeThreshold(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, AtomicAllocator.getInstance().getPointer(result), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite(); return target; @@ -2013,7 +2092,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context), (float) threshold); - + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray); @@ -2039,6 +2119,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); AtomicAllocator.getInstance().getFlowController().registerAction(context, target); @@ -2151,6 +2233,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + if (ptrptr == null) throw new RuntimeException(); @@ -2221,109 +2306,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { } catch (Exception e) { throw new RuntimeException("Op [" + name + "] execution failed", e); } - - /* - long st = profilingConfigurableHookIn(op); - - CudaContext context =(CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - //AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments()); - - if (extraz.get() == null) - extraz.set(new PointerPointer(32)); - - - PointerPointer extras = extraz.get().put( - new CudaPointer(1), - context.getOldStream(), - context.getBufferScalar(), - context.getBufferReduction()); - - val outputArgs = op.outputArguments(); - val inputArgs = op.inputArguments(); - - if (outputArgs.length == 0 && !op.isInplaceCall()) - throw new ND4JIllegalStateException("You can't execute non-inplace CustomOp without outputs being specified"); - - val lc = op.opName().toLowerCase(); - val hash = op.opHash(); - - - val inputShapes = new PointerPointer<>(inputArgs.length * 2); - val inputBuffers = new PointerPointer<>(inputArgs.length * 2); - - int cnt= 0; - for (val in: inputArgs) { - val hp = AtomicAllocator.getInstance().getHostPointer(in.shapeInfoDataBuffer()); - 
inputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(in)); - inputShapes.put(cnt, hp); - - - val dp = AtomicAllocator.getInstance().getPointer(in.shapeInfoDataBuffer(), context); - - inputBuffers.put(cnt + inputArgs.length, AtomicAllocator.getInstance().getPointer(in, context)); - inputShapes.put(cnt+ inputArgs.length, dp); - - if (op.isInplaceCall()) { - val ap = AtomicAllocator.getInstance().getAllocationPoint(in); - if (ap != null) - ap.tickHostWrite(); - } - - cnt++; - } - - - val outputShapes = new PointerPointer<>(outputArgs.length * 2); - val outputBuffers = new PointerPointer<>(outputArgs.length * 2); - - cnt= 0; - for (val out: outputArgs) { - outputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(out)); - outputShapes.put(cnt, AtomicAllocator.getInstance().getHostPointer(out.shapeInfoDataBuffer())); - - outputBuffers.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out, context)); - outputShapes.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out.shapeInfoDataBuffer(), context)); - - val ap = AtomicAllocator.getInstance().getAllocationPoint(out); - - if (ap != null) - ap.tickHostWrite(); - - cnt++; - } - - val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null; - - cnt = 0; - for (val i: op.iArgs()) - iArgs.put(cnt++, i); - - - val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null; - - val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.numBArguments()) : null; - - cnt = 0; - for (val t: op.tArgs()) - tArgs.put(cnt++, t); - - cnt = 0; - for (val b: op.bArgs()) - bArgs.put(cnt++, b); - - try { - val status = OpStatus.byNumber(nativeOps.execCustomOp(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), op.isInplaceCall())); - if (status != OpStatus.ND4J_STATUS_OK) - throw new ND4JIllegalStateException("Op execution failed: " + status); - } catch (Exception e) { - throw new RuntimeException("Op [" + op.opName() + "] execution failed"); - } - - //AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments()); - - profilingConfigurableHookOut(op, st); - return op.outputArguments(); - */ } @Override @@ -2341,6 +2323,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void registerGraph(long id, Pointer graph) { nativeOps.registerGraph(null, id, graph); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } @Override @@ -2368,6 +2353,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result)); if (status != OpStatus.ND4J_STATUS_OK) @@ -2398,6 +2386,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { newMap.put(nodeName, array); } + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + nativeOps.deleteVariablesSet(result); return newMap; @@ -2406,6 +2397,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public void forgetGraph(long id) { nativeOps.unregisterGraph(null, id); + + if (nativeOps.lastErrorCode() != 0) + throw new 
RuntimeException(nativeOps.lastErrorMessage()); } /** @@ -2474,6 +2468,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()), null, (IntPointer) AtomicAllocator.getInstance().getPointer(indices, context)); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates); } @@ -2490,9 +2487,14 @@ public class CudaExecutioner extends DefaultOpExecutioner { ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); + + for (val arr:op.outputArguments()) AtomicAllocator.getInstance().registerAction(ctx, arr); @@ -2527,6 +2529,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { nativeOps.inspectArray(extras, AtomicAllocator.getInstance().getHostPointer(array), (LongPointer) AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(array, ctx), (LongPointer) AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()), debugInfo); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + return INDArrayStatistics.builder() .minValue(debugInfo._minValue()) .maxValue(debugInfo._maxValue()) @@ -2545,6 +2550,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length)); nativeOps.deleteShapeBuffer(dbf); @@ -2556,6 +2564,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack)); val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack)); @@ -2568,6 +2579,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); + if 
(nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); buffer.setConstant(true); @@ -2578,6 +2592,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); buffer.setConstant(true); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 26d363f32..cf779f537 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -100,7 +100,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { @Override public Pointer contextPointer() { for (val v:fastpath_in.values()) { - if (v.isEmpty()) + if (v.isEmpty() || v.isS()) continue; AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); @@ -111,7 +111,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { } for (val v:fastpath_out.values()) { - if (v.isEmpty()) + if (v.isEmpty() || v.isS()) continue; AtomicAllocator.getInstance().getAllocationPoint(v).tickHostRead(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java index 6c8ea85e1..1922d9ced 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/util/CudaArgs.java @@ -17,8 +17,8 @@ package org.nd4j.linalg.jcublas.util; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.Multimap; +import org.nd4j.shade.guava.collect.ArrayListMultimap; +import org.nd4j.shade.guava.collect.Multimap; import lombok.AllArgsConstructor; import lombok.Data; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 47cfa2584..15f6c52ef 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -449,6 +449,60 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { // #endif //DEV_TESTS_TADPACK_H +// Parsed from execution/ErrorReference.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef DEV_TESTS_ERRORREFERENCE_H +// #define DEV_TESTS_ERRORREFERENCE_H + +// #include +// #include + @Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ErrorReference(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ErrorReference position(long position) { + return (ErrorReference)super.position(position); + } + + public ErrorReference() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int errorCode(); + public native @Cast("char*") String errorMessage(); + + public native void setErrorCode(int errorCode); + public native void setErrorMessage(@StdString BytePointer message); + public native void setErrorMessage(@StdString String message); + } + + + +// #endif //DEV_TESTS_ERRORREFERENCE_H + + // Parsed from memory/MemoryType.h // @@ -688,6 +742,18 @@ bool verbose = false; // #include // #include +/** + * This function returns last error code stored, + * @return non-zero if something bad happened + */ +public native int lastErrorCode(); + +/** + * This function returns last error message, if last error code > 0 + * @return + */ +public native @Cast("char*") String lastErrorMessage(); + /** * * @param p @@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); - -/** -* Append an input array -* to the end of a flat array -* in a particular order -* @param offset the offset of the array to start at -* @param order the order -* @param result the result array -* @param resultShapeInfo the shape info for te array -* @param input the input for the array -* @param inputShapeInfo the shape information for that array -*/ -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo); -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") 
LongBuffer dinputShapeInfo); -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo); - -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); - - public native void specialConcat( @Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, @@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include // #include // #include // #include @@ -9950,6 +9951,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ @@ -9985,6 +9987,8 @@ public static final int PREALLOC_SIZE = 33554432; public native void setScalarBuffer(Pointer pointer); public native void setAllocationBuffer(Pointer pointer); + public native ErrorReference errorReference(); + public native void triggerOwnership(@Cast("bool") boolean isOwner); public native int deviceId(); @@ -10038,6 +10042,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } @@ -10067,9 +10072,12 @@ public static final int PREALLOC_SIZE = 33554432; public native int getDeviceID(); public native void setDeviceID(int deviceID); + public native ErrorReference errorReference(); public static native @Cast("bool") boolean isInitialized(); public static native void releaseBuffers(); + + public static native LaunchContext defaultContext(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 4466cf4b5..8f95fe5cb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -32,6 +32,7 @@ import org.bytedeco.javacpp.tools.InfoMapper; "array/ConstantDescriptor.h", "array/ConstantDataBuffer.h", "array/TadPack.h", + "execution/ErrorReference.h", "memory/MemoryType.h", "Environment.h", "types/utf8string.h", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 2b47103c3..cacf32b38 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -106,6 +106,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { functions.put(8, Loader.addressof("LAPACKE_sgesdd")); functions.put(9, Loader.addressof("LAPACKE_dgesdd")); nativeOps.initializeFunctions(functions); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } @Override @@ -489,32 +492,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray toFlattened(char order, Collection matrices) { Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands"); -/* - int length = 0; - val list = new ArrayList(matrices); - val t = list.get(0).dataType(); - for (INDArray m : matrices) { - length += m.length(); - Preconditions.checkArgument(m.dataType() == t, "All operands must have same data type"); - } - INDArray ret = Nd4j.create(t, new long[] {length}, order); - int linearIndex = 0; - PointerPointer dummy = new PointerPointer(new Pointer[] {null}); - for (INDArray m : matrices) { - Nd4j.getCompressor().autoDecompress(m); - - nativeOps.flatten(dummy, linearIndex, order, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null, - m.data().addressPointer(), - (LongPointer) m.shapeInfoDataBuffer().addressPointer(), - null, null); - - linearIndex += m.length(); - } - return ret; - */ return 
Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0]; } @@ -555,6 +533,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { new LongPointerWrapper(tadBuffers.getSecond().pointer()) ); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + return result; } @@ -574,65 +555,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { return toConcat[0]; return Nd4j.exec(new Concat(dimension, toConcat))[0]; - - // legacy implementation -/* - // if reusable var wasn't created for this thread, or is smaller then needed - set it to new value - if (extrazA.get() == null || extrazB.get() == null || extrazSize.get() == null || extrazSize.get() < toConcat.length) { - extrazA.set(new PointerPointer(toConcat.length)); - extrazB.set(new PointerPointer(toConcat.length)); - extrazSize.set(toConcat.length); - } - - PointerPointer shapeInfoPointers = extrazA.get(); - PointerPointer dataPointers = extrazB.get(); - int sumAlongDim = 0; - - long[] outputShape = ArrayUtil.copy(toConcat[0].shape()); - - boolean allScalars = true; - - for (int i = 0; i < toConcat.length; i++) { - Preconditions.checkState(toConcat[i].rank() == outputShape.length, "Encountered different array ranks for concat: input[0].shape()=%ndShape, input[%s].shape()=%ndShape", - toConcat[0], i, toConcat[i]); - - if (toConcat[i].isCompressed()) - Nd4j.getCompressor().decompressi(toConcat[i]); - - Preconditions.checkArgument(toConcat[i].dataType() == toConcat[0].dataType(), "All operands must have same data type: input 0 has type %s, input %s has type %s", - toConcat[0].dataType(), i, toConcat[i].dataType()); - - allScalars &= toConcat[i].rank() == 0; - - shapeInfoPointers.put(i, toConcat[i].shapeInfoDataBuffer().addressPointer()); - dataPointers.put(i, toConcat[i].data().addressPointer()); - sumAlongDim += toConcat[i].size(dimension); - for (int j = 0; j < toConcat[i].rank(); j++) { - - if (j != dimension && toConcat[i].size(j) != outputShape[j]) { - throw new IllegalArgumentException( - "Illegal concatenation at array " + i + " and shape element " + j); - } - } - } - - if (allScalars) { - outputShape = new long[]{sumAlongDim}; - } else { - outputShape[dimension] = sumAlongDim; - } - - INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order()); - - nativeOps.concat(null, dimension, toConcat.length, - dataPointers, shapeInfoPointers, - null, null, - ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), - null, null, - null, null); - - return ret; - */ } @@ -757,6 +679,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { (LongPointer) zTadShapeInfo, new LongPointerWrapper(zTadOffsets)); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); return ret; } @@ -794,6 +718,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { arrays.length, len); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + return target; } @@ -846,6 +773,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { len, true); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + return target; } @@ -983,6 +913,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { arrays.size(), ptrMap, tadPointers, offsetPointers); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); 
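            // For context (illustrative sketch, not part of the patch): every native call in this
            // file, and in NativeOpExecutioner and CudaExecutioner elsewhere in this diff, now gets
            // the same two-line guard immediately after it. Pulled out on its own, the pattern is:
            //
            //     // hypothetical helper name; the patch simply inlines this check at each call site
            //     private void throwIfNativeError() {
            //         if (nativeOps.lastErrorCode() != 0)   // non-zero means the preceding native call failed
            //             throw new RuntimeException(nativeOps.lastErrorMessage());
            //     }
            //
            // lastErrorCode() and lastErrorMessage() are the new NativeOps bindings declared in the
            // generated Nd4jCpu / Nd4jCuda classes further down in this diff, backed by the new
            // ErrorReference stored on ContextBuffers / LaunchContext.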
dataPointers.address(); shapePointers.address(); @@ -990,84 +922,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { offsetPointers.address(); } - - /** - * This method converts Half-precision databuffer to current dType buffer. - * - * @param buffer - * @return - */ - /* - @Override - public DataBuffer restoreFromHalfs(DataBuffer buffer) { - if (buffer.dataType() != DataType.COMPRESSED) - throw new IllegalStateException("DataBuffer contains wrong data: " + buffer.dataType()); - - CompressedDataBuffer comp = (CompressedDataBuffer) buffer; - CompressionDescriptor descriptor = comp.getCompressionDescriptor(); - - DataBuffer targetBuffer = Nd4j.createBuffer(descriptor.getCompressedLength() / 2); - - if (Nd4j.dataType() == DataType.DOUBLE) { - nativeOps.convertHalfsToDoubles( - null, - comp.addressPointer(), - (int) descriptor.getCompressedLength() / 2, - targetBuffer.addressPointer() - ); - } else if (Nd4j.dataType() == DataType.FLOAT) { - nativeOps.convertHalfsToFloats( - null, - comp.addressPointer(), - (int) descriptor.getCompressedLength() / 2, - targetBuffer.addressPointer() - ); - } else { - throw new UnsupportedOperationException("Target dtype isn't supported: " + Nd4j.dataType()); - } - - return targetBuffer; - } - */ - - /** - * This method converts Single/Double precision databuffer to Half-precision databuffer - * - * @param buffer - * @return - */ - /*@Override - public DataBuffer convertToHalfs(DataBuffer buffer) { - // we allocate pointer - ShortPointer pointer = new ShortPointer(buffer.length()); - - if (buffer.dataType() == DataType.DOUBLE) { - nativeOps.convertDoublesToHalfs( - null, - buffer.addressPointer(), - (int) buffer.length(), - pointer - ); - } else if (buffer.dataType() == DataType.FLOAT) { - nativeOps.convertFloatsToHalfs( - null, - buffer.addressPointer(), - (int) buffer.length(), - pointer - ); - } else { - throw new UnsupportedOperationException("Source dtype isn't supported: " + buffer.dataType()); - } - - CompressionDescriptor descriptor = new CompressionDescriptor(buffer, new Float16()); - descriptor.setCompressedLength(buffer.length() * 2); - - - CompressedDataBuffer result = new CompressedDataBuffer(pointer, descriptor); - return result; - } - */ - /** * This method converts Single/Double precision databuffer to Half-precision databuffer * @@ -1081,6 +935,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. 
"); DataBuffer buffer = convertDataEx(typeSrc, source.data(), typeDst); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + source.setData(buffer); if (buffer instanceof CompressedDataBuffer) @@ -1125,6 +982,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { convertDataEx(typeSrc, source, typeDst, buffer); + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); + return buffer; } @@ -1132,6 +992,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, long length) { nativeOps.convertTypes(null, typeSrc.ordinal(), source, length, typeDst.ordinal(), target); + + if (nativeOps.lastErrorCode() != 0) + throw new RuntimeException(nativeOps.lastErrorMessage()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 11373c440..e79c21feb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -234,6 +234,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { null); } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + profilingConfigurableHookOut(op, st); return op.z(); } @@ -563,6 +566,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + return ret; } @@ -644,6 +650,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException(); } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } public INDArray exec(ScalarOp op) { @@ -690,6 +698,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + profilingConfigurableHookOut(op, st); return op.z(); @@ -886,6 +897,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + profilingConfigurableHookOut(op, st); } @@ -962,6 +976,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]"); } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); return op.z(); } @@ -1091,6 +1107,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer, FlatBuffersMapper.getDataTypeAsByte(dataType)); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + } /** @@ -1197,6 +1216,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { numIndexArguments, intArrays, numIntArrays, block.getRealArgumentsPointer(), numRealArguments, 
FlatBuffersMapper.getDataTypeAsByte(dataType)); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } /** @@ -1284,6 +1305,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { op.extraArgsDataBuff(op.z().dataType()).addressPointer()); } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + profilingConfigurableHookOut(op, st); return op.z(); @@ -1370,6 +1394,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { (float) threshold); //long t2 = System.currentTimeMillis(); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + if (cntAbs < 2) return null; @@ -1429,6 +1456,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { loop.convertTypes(null, DataTypeEx.THRESHOLD.ordinal(), buffer.addressPointer(), target.length(), typeDst.ordinal(), target.data().addressPointer()); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + return target; } @@ -1460,6 +1490,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { (IntPointer) buffer.addressPointer(), (float) threshold); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + return affected; } @@ -1473,6 +1506,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { (LongPointer) target.shapeInfoDataBuffer().addressPointer() ); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + return target; } @@ -1673,136 +1709,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } catch (Exception e) { throw new RuntimeException("Op [" + name + "] execution failed", e); } -/* - val name = op.opName().toLowerCase(); - val hash = op.opHash(); - - if (name.equals("noop")) { - return op.outputArguments(); - } - - val inputShapes = getInputShapes(op.numInputArguments()); - val inputBuffers = getInputBuffers(op.numInputArguments()); - - int cnt= 0; - val inputArgs = op.inputArguments(); - for (val in: inputArgs) { - if(in == null) - throw new NullPointerException("Input argument is null for op " + op.getClass().getName()); - - if (!in.isEmpty()) - inputBuffers.put(cnt, in.data().addressPointer()); - - inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); - } - - val outputArgs = op.outputArguments(); - for(int i = 0; i < outputArgs.length; i++) { - if(outputArgs[i] == null) - throw new ND4JIllegalStateException("Op output arguments must not be null! Op " + op.getClass().getName()); - } - - - val outputShapes = getOutputShapes(op.numOutputArguments()); - val outputBuffers = getOutputBuffers(op.numOutputArguments()); - - cnt= 0; - for (val out: outputArgs) { - if(out.isEmpty()){ - outputBuffers.put(cnt, null); - } else { - outputBuffers.put(cnt, out.data().addressPointer()); - } - outputShapes.put(cnt++, out.shapeInfoDataBuffer().addressPointer()); - } - - val iArgs = op.numIArguments() > 0 ? getLongPointerFrom(iArgsPointer,op.numIArguments()) : null; - val tArgs = op.numTArguments() > 0 ? getDoublePointerFrom(tArgsPointer,op.numTArguments()) : null; - val bArgs = op.numBArguments() > 0 ? 
getBooleanPointerFrom(bArgsPointer,op.numBArguments()) : null; - - cnt = 0; - val iArgs1 = op.iArgs(); - for (val i: iArgs1) - iArgs.put(cnt++, i); - - cnt = 0; - val bArgs1 = op.bArgs(); - for (val b: bArgs1) - bArgs.put(cnt++, b); - - cnt = 0; - val tArgs1 = op.tArgs(); - for (val t: tArgs1) - tArgs.put(cnt++, t); - - val t = op.numInputArguments(); - - OpStatus status = OpStatus.ND4J_STATUS_OK; - try { - val code = loop.execCustomOp( - null, - hash, - inputBuffers, - inputShapes, - op.numInputArguments(), - outputBuffers, - outputShapes, - op.numOutputArguments(), - tArgs, op.numTArguments(), - iArgs, op.numIArguments(), - bArgs, op.numBArguments(), - op.isInplaceCall()); - - status = OpStatus.byNumber(code); - - if (status != OpStatus.ND4J_STATUS_OK) - throw new ND4JIllegalStateException("Failed to execute op [" + name + "] with error code [" + status +"]"); - }catch(Exception e) { - val sb = new StringBuilder(); - sb.append("Inputs: [("); - for( int i=0; i 0) - sb.append("), ("); - sb.append(Shape.shapeToStringShort(inputArgs[i])); - } - sb.append(")]. Outputs: [("); - for( int i=0; i 0) - sb.append("), ("); - sb.append(Shape.shapeToStringShort(outputArgs[i])); - } - sb.append(")]. tArgs: "); - if(op.numTArguments() > 0){ - sb.append(Arrays.toString(op.tArgs())); - } else { - sb.append("-"); - } - sb.append(". iArgs: "); - if(op.numIArguments() > 0){ - sb.append(Arrays.toString(op.iArgs())); - } else { - sb.append("-"); - } - if(op instanceof DifferentialFunction){ - String n = ((DifferentialFunction) op).getOwnName(); - if(n != null && !n.equals(op.opName())){ - sb.append(". Op own name: \"").append(n).append("\""); - } - } - log.error("Failed to execute op " + op.opName() + ". Attempted to execute with " + - String.valueOf(op.numInputArguments()) + " inputs, " + - String.valueOf(op.numOutputArguments()) + " outputs, "+ - String.valueOf(op.numTArguments()) + " targs and " + - String.valueOf(op.numIArguments()) + " iargs. 
" + - sb.toString() + - " - Please see above message (printed out from c++) for a possible cause of error."); - throw e; - } - - profilingConfigurableHookOut(op, st); - - return op.outputArguments(); - */ } protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) { @@ -1870,6 +1776,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { ptrptr = loop.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } catch (Throwable t){ StringBuilder sb = new StringBuilder(); sb.append("Inputs: [("); @@ -1893,6 +1802,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { throw t; } + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + if (ptrptr == null) throw new RuntimeException(); @@ -1929,6 +1841,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public void registerGraph(long id, Pointer graph) { loop.registerGraph(null, id, graph); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } @Override @@ -1952,7 +1867,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val newMap = new LinkedHashMap(); - OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); + OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result)); @@ -1996,6 +1914,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public void forgetGraph(long id) { loop.unregisterGraph(null, id); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } /** @@ -2055,6 +1975,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { array.data().addressPointer(), (LongPointer) tadX.getFirst().addressPointer(), (LongPointer) tadX.getSecond().addressPointer(), null, null, null, updates.data().addressPointer(), (LongPointer) tadY.getFirst().addressPointer(), (LongPointer) tadY.getSecond().addressPointer(), null, null, null, (IntPointer) indices.data().addressPointer(), null); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); } @Override @@ -2078,6 +2001,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer()); + + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + if (status != 0) throw new RuntimeException("Op [" + op.opName() + "] execution failed"); @@ -2155,6 +2082,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { loop.inspectArray(null, array.data().addressPointer(), (LongPointer) array.shapeInfoDataBuffer().addressPointer(), null, null, debugInfo); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + return INDArrayStatistics.builder() .minValue(debugInfo._minValue()) .maxValue(debugInfo._maxValue()) @@ -2171,6 +2101,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { 
OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length)); @@ -2183,6 +2115,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack)); val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack)); @@ -2205,11 +2140,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public String runLightBenchmarkSuit(boolean printOut) { - return loop.runLightBenchmarkSuit(printOut); + val s = loop.runLightBenchmarkSuit(printOut); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + + return s; } @Override public String runFullBenchmarkSuit(boolean printOut) { - return loop.runFullBenchmarkSuit(printOut); + val s = loop.runFullBenchmarkSuit(printOut); + if (loop.lastErrorCode() != 0) + throw new RuntimeException(loop.lastErrorMessage()); + + return s; } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 8e71816f8..eeb4d38c3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -467,6 +467,60 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // #endif //DEV_TESTS_TADPACK_H +// Parsed from execution/ErrorReference.h + +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +// #ifndef DEV_TESTS_ERRORREFERENCE_H +// #define DEV_TESTS_ERRORREFERENCE_H + +// #include +// #include + @Namespace("sd") @NoOffset public static class ErrorReference extends Pointer { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public ErrorReference(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public ErrorReference(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public ErrorReference position(long position) { + return (ErrorReference)super.position(position); + } + + public ErrorReference() { super((Pointer)null); allocate(); } + private native void allocate(); + + public native int errorCode(); + public native @Cast("char*") String errorMessage(); + + public native void setErrorCode(int errorCode); + public native void setErrorMessage(@StdString BytePointer message); + public native void setErrorMessage(@StdString String message); + } + + + +// #endif //DEV_TESTS_ERRORREFERENCE_H + + // Parsed from Environment.h /******************************************************************************* @@ -688,6 +742,18 @@ bool verbose = false; // #include // #include +/** + * This function returns last error code stored, + * @return non-zero if something bad happened + */ +public native int lastErrorCode(); + +/** + * This function returns last error message, if last error code > 0 + * @return + */ +public native @Cast("char*") String lastErrorMessage(); + /** * * @param p @@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); - -/** -* Append an input array -* to the end of a flat array -* in a particular order -* @param offset the offset of the array to start at -* @param order the order -* @param result the result array -* @param resultShapeInfo the shape info for te array -* @param input the input for the array -* @param inputShapeInfo the shape information for that array -*/ -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo); -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") LongBuffer dinputShapeInfo); -public native void flatten( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int offset, - char order, - Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo, - Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo, - Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo); - -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int 
numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); -public native void concat( - @Cast("Nd4jPointer*") PointerPointer extraPointers, - int dimension, - int numArrays, - @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo, - @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo, - Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo, - Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo, - @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers); - - public native void specialConcat( @Cast("Nd4jPointer*") PointerPointer extraPointers, int dimension, @@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include // #include // #include // #include @@ -16496,15 +16497,18 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * creates identity 2D matrix or batch of identical 2D identity matrices - * + * * Input array: * provide some array - in any case operation simply neglects it - * + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * * Input integer arguments: * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order * IArgs[1] - the number of rows in output inner-most 2D identity matrix * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape + * IArgs[3,4,...] 
- optional, shape of batch, output matrix will have leading batch dimensions of this shape */ // #if NOT_EXCLUDED(OP_eye) @Namespace("nd4j::ops") public static class eye extends DeclarableCustomOp { @@ -16598,10 +16602,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * clip a list of given tensors with given average norm when needed - * + * * Input: * a list of tensors (at least one) - * + * * Input floating point argument: * clip_norm - a value that used as threshold value and norm to be used * @@ -16749,12 +16753,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * returns histogram (as 1D array) with fixed bins width - * + * * Input arrays: - * - input array with elements to be binned into output histogram + * - input array with elements to be binned into output histogram * - range array with first element being bottom limit and second element being top limit of histogram, please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * + * * Input integer arguments: * nbins (optional) - number of histogram bins, default value is 100 */ @@ -21822,7 +21826,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_shift_bits) - @Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class shift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public shift_bits(Pointer p) { super(p); } @@ -21835,7 +21839,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21847,7 +21850,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_rshift_bits) - @Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class rshift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public rshift_bits(Pointer p) { super(p); } @@ -21860,7 +21863,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21872,7 +21874,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_cyclic_shift_bits) - @Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class cyclic_shift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ public cyclic_shift_bits(Pointer p) { super(p); } @@ -21885,7 +21887,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public cyclic_shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21897,7 +21898,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public cyclic_rshift_bits(Pointer p) { super(p); } @@ -21910,6 +21911,30 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public cyclic_rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); + } +// #endif + + /** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bits_hamming_distance) + @Namespace("nd4j::ops") public static class bits_hamming_distance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bits_hamming_distance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bits_hamming_distance position(long position) { + return (bits_hamming_distance)super.position(position); + } + + public bits_hamming_distance() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -22877,6 +22902,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. 
*/ @@ -22912,6 +22938,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native void setScalarBuffer(Pointer pointer); public native void setAllocationBuffer(Pointer pointer); + public native ErrorReference errorReference(); + public native void triggerOwnership(@Cast("bool") boolean isOwner); public native int deviceId(); @@ -22961,6 +22989,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { static { Loader.load(); } @@ -22985,9 +23014,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native int getDeviceID(); public native void setDeviceID(int deviceID); + public native ErrorReference errorReference(); public static native @Cast("bool") boolean isInitialized(); public static native void releaseBuffers(); + + public static native LaunchContext defaultContext(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 58a2a7d02..554016686 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -38,6 +38,7 @@ import java.util.Scanner; "array/ConstantDataBuffer.h", "array/ConstantDescriptor.h", "array/TadPack.h", + "execution/ErrorReference.h", "Environment.h", "types/utf8string.h", "NativeOps.h", diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index 99731899f..5f1d372ff 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -159,7 +159,7 @@ nd4j-tests-cpu - true + false diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9437ad7b2..057f610bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -1112,12 +1112,12 @@ public class LayerOpValidation extends BaseOpValidation { @Test public void testLayerNorm() { - final INDArray random = Nd4j.rand(new int[]{10, 4}); + final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(new int[]{1, 4}); - final INDArray bias = Nd4j.rand(new int[]{1, 4}); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); + final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain).addRowVector(bias); final INDArray expOut = res.norm1(); @@ -1126,38 +1126,70 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable sdInput = sd.var("input", standardized); SDVariable sdGain = sd.var("gain", gain); SDVariable sdBias = sd.var("bias", bias); - SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis); + SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis); out.norm1("out"); String err = OpValidation.validate(new TestCase(sd) .expectedOutput("out", expOut) .gradientCheck(true)); - assertNull(err, err); + assertNull(err); } 
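        // For context (illustrative, not part of the patch): the layerNorm signature change exercised
        // above inserts a boolean ahead of the axis varargs. Existing callers migrate as follows,
        // both forms taken from the removed/added lines in this hunk:
        //
        //     SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis);        // before
        //     SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis);  // after
        //
        // Judging from the nchw flag passed in testLayerNorm4d below, the boolean appears to select
        // channels-first versus channels-last handling; that reading is an assumption here, not a
        // documented contract. The no-bias overload and the LayerNorm op constructor gain the same
        // boolean in the corresponding hunks below.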
+ @Test + public void testLayerNorm4d() { + int mb = 3; + int ch = 4; + for(boolean nchw : new boolean[]{true, false}) { + double eps = 0.0; + INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch}); + INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); + INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch}); + INDArray mean = x.mean(true, 1, 2, 3); + INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1); + + INDArray standardized = x.sub(mean).div(std); + INDArray exp = standardized.mul(gain4d).add(bias4d); + + final int[] axis = new int[]{1, 2, 3}; + SameDiff sd = SameDiff.create(); + SDVariable sdInput = sd.var("input", x); + SDVariable sdGain = sd.var("gain", gain4d.reshape(ch)); + SDVariable sdBias = sd.var("bias", bias4d.reshape(ch)); + SDVariable out = sd.nn.layerNorm("layernorm", sdInput, sdGain, sdBias, nchw, axis); + + SDVariable loss = sd.loss.l2Loss(out); + + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("layernorm", exp) + .gradientCheck(true)); + assertNull(err); + } + } + + @Test public void testLayerNormOP() { - final INDArray random = Nd4j.rand(new int[]{10, 4}); + final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(new int[]{1, 4}); - final INDArray bias = Nd4j.rand(new int[]{1, 4}); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); + final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain).addRowVector(bias); final INDArray output = Nd4j.zerosLike(res); - Nd4j.getExecutioner().exec(new LayerNorm(standardized, gain, bias, output, 1)); + Nd4j.getExecutioner().exec(new LayerNorm(standardized, gain, bias, output, true, 1)); assertEquals(res, output); } @Test public void testLayerNormNoBias() { - final INDArray random = Nd4j.rand(new int[]{10, 4}); + final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(new int[]{1, 4}); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain); final INDArray expOut = res.norm1(); @@ -1165,7 +1197,7 @@ public class LayerOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable sdInput = sd.var("input", standardized); SDVariable sdGain = sd.var("gain", gain); - SDVariable out = sd.nn.layerNorm(sdInput, sdGain, axis); + SDVariable out = sd.nn.layerNorm(sdInput, sdGain, true, axis); out.norm1("out"); String err = OpValidation.validate(new TestCase(sd) @@ -1176,22 +1208,22 @@ public class LayerOpValidation extends BaseOpValidation { @Test public void testLayerNormOPNoBias() { - final INDArray random = Nd4j.rand(new int[]{10, 4}); + final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(new int[]{1, 4}); + final INDArray gain = Nd4j.rand(DataType.DOUBLE,4); final INDArray res = standardized.mulRowVector(gain); final INDArray output = Nd4j.zerosLike(res); - Nd4j.getExecutioner().exec(new LayerNorm(standardized, gain, output, 1)); + Nd4j.getExecutioner().exec(new 
LayerNorm(standardized, gain, output, true, 1)); assertEquals(res, output); } @Test public void testLayerNormNoDeviation() { - final INDArray random = Nd4j.rand(new int[]{10, 4}); + final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4); for (int i = 0; i < 4; i++) { random.putScalar(1,i, 7); } @@ -1199,8 +1231,8 @@ public class LayerOpValidation extends BaseOpValidation { final INDArray standardized = random.ulike(); Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1)); - final INDArray gain = Nd4j.rand(new int[]{1, 4}); - final INDArray bias = Nd4j.rand(new int[]{1, 4}); + final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4); + final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4); final INDArray res = standardized.mulRowVector(gain).addRowVector(bias); final INDArray expOut = res.norm1(); @@ -1209,7 +1241,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable sdInput = sd.var("input", standardized); SDVariable sdGain = sd.var("gain", gain); SDVariable sdBias = sd.var("bias", bias); - SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, axis); + SDVariable out = sd.nn.layerNorm(sdInput, sdGain, sdBias, true, axis); out.norm1("out"); String err = OpValidation.validate(new TestCase(sd) @@ -1297,12 +1329,11 @@ public class LayerOpValidation extends BaseOpValidation { } @Test - @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testLayerNormMixedOrders(){ Nd4j.getRandom().setSeed(12345); INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f'); - INDArray gain = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f'); - INDArray bias = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f'); + INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); + INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f'); INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f'); INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c'); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 2a4b032b5..eb228bf1f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -438,6 +438,27 @@ public class MiscOpValidation extends BaseOpValidation { assertEquals(failed.toString(), 0, failed.size()); } + @Test + public void testScatterUpdate(){ + INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); + INDArray updates = Nd4j.create(new float[][]{ + {100, 101, 102}, + {200, 201, 202}}); + INDArray indices = Nd4j.createFromArray(2, 5); + + INDArray exp = x.dup(); + exp.putRow(2, updates.getRow(0)); + exp.putRow(5, updates.getRow(1)); + + INDArray out = exp.ulike(); + Nd4j.exec(DynamicCustomOp.builder("scatter_upd") + .addInputs(x, indices, updates) + .addOutputs(out) + .build()); + + assertEquals(exp, out); + } + @Test public void testGatherGradient() { Nd4j.getRandom().setSeed(12345); @@ -1688,4 +1709,59 @@ public class MiscOpValidation extends BaseOpValidation { Nd4j.getExecutioner().exec(op); } + + @Test + public void testHistogramFixedWidth(){ + //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] + INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); + INDArray range = Nd4j.createFromArray(0.0, 1.0); + INDArray n = Nd4j.scalar(5); + + 
INDArray out = Nd4j.create(DataType.INT, 5); + + Nd4j.exec(DynamicCustomOp.builder("histogram_fixed_width") + .addInputs(in, range, n) + .addOutputs(out) + .build()); + + INDArray exp = Nd4j.createFromArray(3, 1, 2, 0, 1); + assertEquals(exp, out); + } + + @Test + public void testDynamicPartition(){ + INDArray data = Nd4j.createFromArray(2, 1, 2, 0); + INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); + INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) + .addIntegerArguments(3) //3 partitions + .addInputs(data, partitions).build()); + + INDArray exp0 = Nd4j.createFromArray(2, 0); + INDArray exp1 = Nd4j.createFromArray(2); + INDArray exp2 = Nd4j.createFromArray(1); + + assertEquals(exp0, out[0]); //Usually just gives [0,0] + assertEquals(exp1, out[1]); + assertEquals(exp2, out[2]); + } + + @Test + public void testListDiff(){ + INDArray x = Nd4j.createFromArray(0, 1, 2, 3); + INDArray y = Nd4j.createFromArray(3, 1); + + INDArray out = Nd4j.create(DataType.INT, 2); + INDArray outIdx = Nd4j.create(DataType.INT, 2); + + Nd4j.exec(DynamicCustomOp.builder("listdiff") + .addInputs(x, y) + .addOutputs(out, outIdx) + .build()); + + INDArray exp = Nd4j.createFromArray(0, 2); + + assertEquals(exp, out); //Values in x not in y + assertEquals(exp, outIdx); //Indices of the values in x not in y + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 646cae454..8d64f6404 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; @@ -371,6 +372,14 @@ public class RandomOpValidation extends BaseOpValidation { assertNull(OpValidation.validate(tc)); } + } + @Test + public void testAllEmptyReduce(){ + INDArray x = Nd4j.createFromArray(true, true, true); + All all = new All(x); + all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction) + INDArray out = Nd4j.exec(all); + assertEquals(x, out); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index e53cfa5ff..bf4b331a3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1342,6 +1342,26 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(failed.toString(), 0, failed.size()); } + @Test + public void testSegmentMean(){ + INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); + INDArray segmentIds = 
Nd4j.createFromArray(0, 0, 1, 1, 2, 2); + + INDArray out = Nd4j.create(DataType.FLOAT, 3, 3); + + Nd4j.exec(DynamicCustomOp.builder("segment_mean") + .addInputs(x, segmentIds) + .addOutputs(out) + .build()); + + INDArray exp = out.like(); + exp.putRow(0, x.getRow(0).add(x.getRow(1)).muli(0.5)); + exp.putRow(1, x.getRow(2).add(x.getRow(3)).muli(0.5)); + exp.putRow(2, x.getRow(4).add(x.getRow(5)).muli(0.5)); + + assertEquals(exp, out); + } + @Test public void testSequenceMask() { OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue? @@ -1599,21 +1619,6 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(expected, result.eval()); } - @Test - public void testBroadcast() { - OpValidationSuite.ignoreFailing(); - SameDiff sd = SameDiff.create(); - SDVariable in = sd.var("in", Nd4j.rand(3, 4)); - SDVariable broadcast = sd.f().broadcast(in, 3, 4, 5); - - INDArray out = sd.execAndEndResult(); - assertArrayEquals(new long[]{3, 4, 5}, out.shape()); - - for (int i = 0; i < 5; i++) { - assertEquals(in.getArr(), out.get(all(), all(), point(i))); - } - } - @Test public void testSlice2d() { @@ -2109,6 +2114,38 @@ public class ShapeOpValidation extends BaseOpValidation { Nd4j.exec(op); } + @Test + public void testConcatEmpty2(){ + INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); + INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); + + DynamicCustomOp op = DynamicCustomOp.builder("concat") + .addInputs(empty10a, empty10b) + .addIntegerArguments(0) //axis = 0 + .build(); + + List l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{2, 0}, l.get(0).getShape()); + assertEquals(DataType.INT, l.get(0).dataType()); + + op.addOutputArgument(Nd4j.create(DataType.INT, 2, 0)); + Nd4j.exec(op); + + + op = DynamicCustomOp.builder("concat") + .addInputs(empty10a, empty10b) + .addIntegerArguments(1) //axis = 1 + .build(); + l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{1, 0}, l.get(0).getShape()); + op.addOutputArgument(Nd4j.create(DataType.INT, 1, 0)); + Nd4j.exec(op); + } + @Test public void testEmptyGather(){ /* @@ -2434,4 +2471,5 @@ public class ShapeOpValidation extends BaseOpValidation { .addInputs(Nd4j.createFromArray(1, 0)) .build(); } + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 0d177027d..fdd2b3160 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -412,7 +412,25 @@ public class TransformOpValidation extends BaseOpValidation { .expectedOutput("dp0", expOut[0]) .expectedOutput("dp1", expOut[1]) .gradientCheck(true)); - assertNull(err, err); + assertNull(err); + } + + @Test + public void testDynamicPartition2(){ + INDArray data = Nd4j.createFromArray(2, 1, 2, 0); + INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); + INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) + .addIntegerArguments(3) //3 partitions + .addInputs(data, partitions).build()); + + INDArray exp0 = Nd4j.createFromArray(2, 0); + 
INDArray exp1 = Nd4j.createFromArray(2); + INDArray exp2 = Nd4j.createFromArray(1); + + assertEquals(exp0, out[0]); //Usually just gives [0,0] + assertEquals(exp1, out[1]); + assertEquals(exp2, out[2]); } @Test @@ -1612,6 +1630,27 @@ public class TransformOpValidation extends BaseOpValidation { } } + @Test + public void testTopK1(){ + INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); + INDArray k = Nd4j.scalar(1); + INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); + INDArray outIdx = Nd4j.create(DataType.INT, 1); + + Nd4j.exec(DynamicCustomOp.builder("top_k") + .addInputs(x, k) + .addOutputs(outValue, outIdx) + .addBooleanArguments(false) //not sorted + .addIntegerArguments(1) + .build()); + + INDArray expValue = Nd4j.createFromArray(10.0); + INDArray expIdx = Nd4j.createFromArray(3); + + assertEquals(expValue, outValue); + assertEquals(expIdx, outIdx); + } + @Test public void testInTopK() { for( int k=4; k>= 1; k--){ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index ef6d1268b..bb15b3392 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -53,7 +53,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.Linear; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; @@ -522,39 +521,6 @@ public class SameDiffTests extends BaseNd4jTest { } - @Test - public void testLinearModule() { - int nIn = 5; - Linear linear = Linear.execBuilder() - .nIn(nIn) - .nOut(4) - .weightInitScheme(new UniformInitScheme('f', nIn)) - .biasWeightInitScheme(new ZeroInitScheme('f')) - .build(); - linear.exec(Nd4j.linspace(1, 20, 20).reshape(4, 5)); - assertEquals(1, linear.numOutputArguments()); - - } - - - @Test - public void testLinearModule2() { - Linear linear = Linear.execBuilder() - .nIn(3) - .nOut(2) - .weightInitScheme(new OneInitScheme('f')) - .biasWeightInitScheme(new ZeroInitScheme('f')) - .build(); - linear.exec(Nd4j.linspace(1, 6, 6).reshape(2, 3)); - INDArray assertion = Nd4j.create(new double[][]{ - {6, 6}, - {15, 15} - }); - assertEquals(assertion, linear.outputArguments()[0]); - - } - - @Test public void testDefineFunctionArrayExistence() { SameDiff sameDiff = SameDiff.create(); @@ -3577,4 +3543,24 @@ public class SameDiffTests extends BaseNd4jTest { assertEquals(e, mod.eval()); } + + @Test + public void castShapeTest1(){ + SameDiff sd = SameDiff.create(); + SDVariable x = sd.constant(Nd4j.createFromArray(1, 2, 3, 4)); + SDVariable casted = x.castTo(DataType.FLOAT); + + assertEquals(casted.dataType(), DataType.FLOAT); + } + + @Test + @Ignore // casted shape is null + public void castShapeTestEmpty(){ + SameDiff sd = SameDiff.create(); + SDVariable x = sd.constant(Nd4j.empty(DataType.INT)); + SDVariable casted = x.castTo(DataType.FLOAT); + + assertEquals(casted.dataType(), DataType.FLOAT); + assertTrue(casted.getShapeDescriptor().isEmpty()); + } } diff --git 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index d182377fe..1bd6fd22c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -291,7 +291,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for (Metric m : Metric.values()) { double d1 = e4d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-6); + assertEquals(m.toString(), d2, d1, 1e-5); } } @@ -385,7 +385,7 @@ public class RegressionEvalTest extends BaseNd4jTest { for(Metric m : Metric.values()){ double d1 = e4d_m1.scoreForMetric(m); double d2 = e2d_m1.scoreForMetric(m); - assertEquals(m.toString(), d2, d1, 1e-6); + assertEquals(m.toString(), d2, d1, 1e-5); } //Check per-output masking: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index ab86f829e..ecc81c981 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -344,6 +344,8 @@ public class TFGraphTestAllHelper { System.out.println("Pass: " + varName); } else { System.out.println("FAIL: " + varName); + System.out.println("TF:\n" + tfValue); + System.out.println("SD:\n" + sdVal); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 4f3520d31..1fef4a07b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -180,8 +180,8 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); try { - TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, - TFGraphTestAllHelper.LOADER, maxRE, minAbs); + TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs); + //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); } catch (Throwable t){ log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? 
null : inputs.keySet()), t); throw t; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index b302c8c0f..22b911468 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -1334,7 +1334,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray zC = Nd4j.create(shape, 'c'); zC.setData(Nd4j.linspace(1, 24, 24, DataType.DOUBLE).data()); for (int tad = 0; tad < zC.tensorsAlongDimension(dim); tad++) { - INDArray javaTad = zC.javaTensorAlongDimension(tad, dim); + INDArray javaTad = zC.tensorAlongDimension(tad, dim); System.out.println("Tad " + tad + " is " + zC.tensorAlongDimension(tad, dim)); } @@ -5216,6 +5216,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + log.info("Array shapeInfo: {}", array.shapeInfoJava()); + INDArray rev = Nd4j.reverse(array); assertEquals(exp, rev); @@ -5226,7 +5228,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0]; + INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, array.ulike()))[0]; assertEquals(exp, rev); } @@ -5236,7 +5238,7 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0]; + INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array,array.ulike()))[0]; assertEquals(exp, rev); } @@ -5335,11 +5337,103 @@ public class Nd4jTestsC extends BaseNd4jTest { assertNotNull(lsd); //Fails here on CUDA, OK on native/cpu } + @Test + public void testReverseSmall_1() { + val array = Nd4j.linspace(1, 10, 10, DataType.INT); + val exp = array.dup(array.ordering()); + + Transforms.reverse(array, false); + Transforms.reverse(array, false); + + val jexp = exp.data().asInt(); + val jarr = array.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, array); + } + + @Test + public void testReverseSmall_2() { + val array = Nd4j.linspace(1, 10, 10, DataType.INT); + val exp = array.dup(array.ordering()); + + val reversed = Transforms.reverse(array, true); + val rereversed = Transforms.reverse(reversed, true); + + val jexp = exp.data().asInt(); + val jarr = rereversed.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, rereversed); + } + + @Test + public void testReverseSmall_3() { + val array = Nd4j.linspace(1, 11, 11, DataType.INT); + val exp = array.dup(array.ordering()); + + Transforms.reverse(array, false); + + log.info("Reversed shapeInfo: {}", array.shapeInfoJava()); + log.info("Reversed: {}", array); + + Transforms.reverse(array, false); + + val jexp = exp.data().asInt(); + val jarr = array.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, array); + } + + @Test + public void testReverseSmall_4() { + val array = Nd4j.linspace(1, 11, 11, DataType.INT); + val exp = array.dup(array.ordering()); + + val reversed = 
Transforms.reverse(array, true); + + log.info("Reversed: {}", reversed); + + val rereversed = Transforms.reverse(reversed, true); + + val jexp = exp.data().asInt(); + val jarr = rereversed.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, rereversed); + } + + @Test + public void testReverse_1() { + val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT); + val exp = array.dup(array.ordering()); + + Transforms.reverse(array, false); + Transforms.reverse(array, false); + + val jexp = exp.data().asInt(); + val jarr = array.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, array); + } + + @Test + public void testReverse_2() { + val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT); + val exp = array.dup(array.ordering()); + + val reversed = Transforms.reverse(array, true); + val rereversed = Transforms.reverse(reversed, true); + + val jexp = exp.data().asInt(); + val jarr = rereversed.data().asInt(); + assertArrayEquals(jexp, jarr); + assertEquals(exp, rereversed); + } + @Test public void testNativeSort3_1() { INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1); INDArray exp = array.dup(); Transforms.reverse(array, false); + log.info("Reverse: {}", array); long time1 = System.currentTimeMillis(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index ab724ee1b..b49d49b79 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -100,20 +100,7 @@ public class IndexingTests extends BaseNd4jTest { INDArray vals = Nd4j.valueArrayOf(new long[] {2,2,2,2},5, DataType.DOUBLE); assertEquals(vals,x); } - - - @Test @Ignore - public void testIndexGetDuplicate() { - List> indices = new ArrayList<>(); - indices.add(Arrays.asList(0,0)); - INDArray linspace = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); - INDArray get = linspace.get(indices); - INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}).reshape('c',2,2,2,2).castTo(DataType.DOUBLE); - assertEquals(assertion,get); - } - - - + @Test public void testGetScalar() { INDArray arr = Nd4j.linspace(1, 5, 5, DataType.DOUBLE); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index 0904bdaee..960bbd646 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -16,7 +16,6 @@ package org.nd4j.linalg.api.indexing; -import org.joda.time.Interval; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java index 755fd4d20..cede0543a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/tad/TestTensorAlongDimension.java @@ -57,7 +57,7 @@ public class 
TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < n; i++) { StopWatch javaTiming = new StopWatch(); javaTiming.start(); - row.javaTensorAlongDimension(0, 0); + row.tensorAlongDimension(0, 0); javaTiming.stop(); StopWatch cTiming = new StopWatch(); cTiming.start(); @@ -98,7 +98,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { assertEquals(cols, arr.tensorsAlongDimension(0)); for (int i = 0; i < cols; i++) { INDArray tad = arr.tensorAlongDimension(i, 0); - INDArray javaTad = arr.javaTensorAlongDimension(i, 0); + INDArray javaTad = arr.tensorAlongDimension(i, 0); assertEquals(javaTad, tad); assertArrayEquals(new int[] {rows}, tad.shape()); //assertEquals(testValues.javaTensorAlongDimension(i, 0), tad); @@ -120,7 +120,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { list = NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, new int[]{rows, cols, dim2}, DataType.DOUBLE); for (Pair p : list) { INDArray arr = p.getFirst().assign(testValues); - INDArray javaTad = arr.javaTensorAlongDimension(0, 0); + INDArray javaTad = arr.tensorAlongDimension(0, 0); INDArray tadTest = arr.tensorAlongDimension(0, 0); assertEquals(javaTad, tadTest); //Along dimension 0: expect row vector with length 'rows' @@ -165,7 +165,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { //Along dimension 0,1: expect matrix with shape [rows,cols] assertEquals(dim2, arr.tensorsAlongDimension(0, 1)); for (int i = 0; i < dim2; i++) { - INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 1); + INDArray javaTad = arr.tensorAlongDimension(i, 0, 1); INDArray tad = arr.tensorAlongDimension(i, 0, 1); int javaEleStride = javaTad.elementWiseStride(); int testTad = tad.elementWiseStride(); @@ -178,11 +178,11 @@ public class TestTensorAlongDimension extends BaseNd4jTest { //Along dimension 0,2: expect matrix with shape [rows,dim2] assertEquals(cols, arr.tensorsAlongDimension(0, 2)); for (int i = 0; i < cols; i++) { - INDArray javaTad = arr.javaTensorAlongDimension(i, 0, 2); + INDArray javaTad = arr.tensorAlongDimension(i, 0, 2); INDArray tad = arr.tensorAlongDimension(i, 0, 2); assertEquals(javaTad, tad); assertArrayEquals(new long[] {rows, dim2}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad); + assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad); } //Along dimension 1,2: expect matrix with shape [cols,dim2] @@ -190,7 +190,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < rows; i++) { INDArray tad = arr.tensorAlongDimension(i, 1, 2); assertArrayEquals(new long[] {cols, dim2}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad); + assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad); } } @@ -207,7 +207,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < dim2 * dim3; i++) { INDArray tad = arr.tensorAlongDimension(i, 0, 1); assertArrayEquals(new long[] {rows, cols}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 0, 1), tad); + assertEquals(testValues.tensorAlongDimension(i, 0, 1), tad); } //Along dimension 0,2: expect matrix with shape [rows,dim2] @@ -215,7 +215,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < cols * dim3; i++) { INDArray tad = arr.tensorAlongDimension(i, 0, 2); assertArrayEquals(new long[] {rows, dim2}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 0, 2), tad); + assertEquals(testValues.tensorAlongDimension(i, 0, 2), tad); } //Along 
dimension 0,3: expect matrix with shape [rows,dim3] @@ -223,7 +223,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < cols * dim2; i++) { INDArray tad = arr.tensorAlongDimension(i, 0, 3); assertArrayEquals(new long[] {rows, dim3}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 0, 3), tad); + assertEquals(testValues.tensorAlongDimension(i, 0, 3), tad); } @@ -232,7 +232,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < rows * dim3; i++) { INDArray tad = arr.tensorAlongDimension(i, 1, 2); assertArrayEquals(new long[] {cols, dim2}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 1, 2), tad); + assertEquals(testValues.tensorAlongDimension(i, 1, 2), tad); } //Along dimension 1,3: expect matrix with shape [cols,dim3] @@ -240,7 +240,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < rows * dim2; i++) { INDArray tad = arr.tensorAlongDimension(i, 1, 3); assertArrayEquals(new long[] {cols, dim3}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 1, 3), tad); + assertEquals(testValues.tensorAlongDimension(i, 1, 3), tad); } //Along dimension 2,3: expect matrix with shape [dim2,dim3] @@ -248,7 +248,7 @@ public class TestTensorAlongDimension extends BaseNd4jTest { for (int i = 0; i < rows * cols; i++) { INDArray tad = arr.tensorAlongDimension(i, 2, 3); assertArrayEquals(new long[] {dim2, dim3}, tad.shape()); - assertEquals(testValues.javaTensorAlongDimension(i, 2, 3), tad); + assertEquals(testValues.tensorAlongDimension(i, 2, 3), tad); } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java index 75a91263d..cfff870a6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/SpecialTests.java @@ -551,15 +551,13 @@ public class SpecialTests extends BaseNd4jTest { int[] inputShape = new int[]{1, 2, 2, 1}; int M = 2; - int[] blockShape = new int[]{M, 1}; - int[] paddingShape = new int[]{M, 2}; INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); - INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); - INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT); + INDArray blocks = Nd4j.createFromArray(2, 2); + INDArray padding = Nd4j.createFromArray(0, 0, 0, 0).reshape(2,2); INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1); - val op = DynamicCustomOp.builder("space_to_batch") + val op = DynamicCustomOp.builder("space_to_batch_nd") .addInputs(input, blocks, padding) .addOutputs(expOut).build(); Nd4j.getExecutioner().execAndReturn(op); @@ -573,15 +571,13 @@ public class SpecialTests extends BaseNd4jTest { int[] inputShape = new int[]{miniBatch, 1, 1, 1}; int M = 2; - int[] blockShape = new int[]{M, 1}; - int[] cropShape = new int[]{M, 2}; INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); - INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); - INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT); + INDArray blocks = Nd4j.createFromArray(2, 2); + INDArray crops = Nd4j.createFromArray(0, 0, 0, 0).reshape(2,2); INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1); - DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space") + 
DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space_nd") .addInputs(input, blocks, crops) .addOutputs(expOut).build(); Nd4j.getExecutioner().execAndReturn(op); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 6c4595a0c..f325348fb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -733,6 +733,46 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test + public void testListDiff(){ + INDArray x = Nd4j.createFromArray(0, 1, 2, 3); + INDArray y = Nd4j.createFromArray(3, 1); + + INDArray out = Nd4j.create(DataType.INT, 2); + INDArray outIdx = Nd4j.create(DataType.INT, 2); + + Nd4j.exec(DynamicCustomOp.builder("listdiff") + .addInputs(x, y) + .addOutputs(out, outIdx) + .build()); + + INDArray exp = Nd4j.createFromArray(0, 2); + + assertEquals(exp, out); //Values in x not in y + assertEquals(exp, outIdx); //Indices of the values in x not in y + } + + @Test + public void testTopK1(){ + INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); + INDArray k = Nd4j.scalar(1); + INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); + INDArray outIdx = Nd4j.create(DataType.INT, 1); + + Nd4j.exec(DynamicCustomOp.builder("top_k") + .addInputs(x, k) + .addOutputs(outValue, outIdx) + .addBooleanArguments(false) //not sorted + .addIntegerArguments(1) + .build()); + + INDArray expValue = Nd4j.createFromArray(10.0); + INDArray expIdx = Nd4j.createFromArray(3); + + assertEquals(expValue, outValue); + assertEquals(expIdx, outIdx); + } + @Test public void testMaxPool2Dbp_1() { val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 49391b74c..52ede954a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -614,7 +614,7 @@ public class OpExecutionerTests extends BaseNd4jTest { assertEquals(arr5s.getDouble(i), 16, 1e-1); System.out.println("6d"); INDArray arr6 = Nd4j.ones(1, 1, 4, 4, 4, 4); - INDArray arr6Tad = arr6.javaTensorAlongDimension(0, 2, 3); + INDArray arr6Tad = arr6.tensorAlongDimension(0, 2, 3); INDArray arr6s = arr6.sum(2, 3); for (int i = 0; i < arr6s.length(); i++) assertEquals(arr6s.getDouble(i), 16, 1e-1); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 261e1e300..3bef69c19 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -299,6 +300,26 @@ public class EmptyTests extends BaseNd4jTest { assertNotNull(result[0].shapeInfoDataBuffer().asLong()); } + @Test + public 
void testAllEmptyReduce(){ + INDArray x = Nd4j.createFromArray(true, true, true); + val all = new All(x); + all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction) + INDArray out = Nd4j.exec(all); + assertEquals(x, out); + } + + @Test + public void testEmptyNoop() { + val output = Nd4j.empty(DataType.LONG); + + val op = DynamicCustomOp.builder("noop") + .addOutputs(output) + .build(); + + Nd4j.exec(op); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java index f4d2bf80c..6f47d00da 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/TADTests.java @@ -79,7 +79,7 @@ public class TADTests extends BaseNd4jTest { int[] shape = new int[] {e, x}; Arrays.sort(shape); - INDArray assertion = array.javaTensorAlongDimension(0, shape); + INDArray assertion = array.tensorAlongDimension(0, shape); INDArray test = array.tensorAlongDimension(0, shape); assertEquals(assertion, test); @@ -101,7 +101,7 @@ public class TADTests extends BaseNd4jTest { Arrays.sort(shape); log.info("About to do shape: " + Arrays.toString(shape) + " for array of shape " + array.shapeInfoToString()); - INDArray assertion = array.javaTensorAlongDimension(0, shape); + INDArray assertion = array.tensorAlongDimension(0, shape); INDArray test = array.tensorAlongDimension(0, shape); assertEquals(assertion, test); //assertEquals(assertion.shapeInfoDataBuffer(), test.shapeInfoDataBuffer()); @@ -121,8 +121,8 @@ public class TADTests extends BaseNd4jTest { public void testMysteriousCrash() { INDArray arrayF = Nd4j.create(new int[] {1, 1, 4, 4}, 'f'); INDArray arrayC = Nd4j.create(new int[] {1, 1, 4, 4}, 'c'); - INDArray javaCTad = arrayC.javaTensorAlongDimension(0, 2, 3); - INDArray javaFTad = arrayF.javaTensorAlongDimension(0, 2, 3); + INDArray javaCTad = arrayC.tensorAlongDimension(0, 2, 3); + INDArray javaFTad = arrayF.tensorAlongDimension(0, 2, 3); Pair tadBuffersF = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(arrayF, 2, 3); Pair tadBuffersC = diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java index 596bf16a7..806cf4d08 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/concat/ConcatTestsC.java @@ -185,7 +185,7 @@ public class ConcatTestsC extends BaseNd4jTest { //ConcatV2, dim 1 second = Nd4j.linspace(24, 32, 8, Nd4j.dataType()).reshape('c', 2, 1, 4); for (int i = 0; i < second.tensorsAlongDimension(1); i++) { - INDArray secondTad = second.javaTensorAlongDimension(i, 1); + INDArray secondTad = second.tensorAlongDimension(i, 1); System.out.println(second.tensorAlongDimension(i, 1)); } diff --git a/nd4j/nd4j-common/pom.xml b/nd4j/nd4j-common/pom.xml index d423e8012..d75c82cf4 100644 --- a/nd4j/nd4j-common/pom.xml +++ b/nd4j/nd4j-common/pom.xml @@ -83,6 +83,12 @@ ${project.version} + + org.nd4j + guava + ${project.version} + + org.slf4j slf4j-api @@ -117,11 +123,6 @@ ${commons-compress.version} - - com.google.guava - guava - - commons-codec commons-codec diff --git 
a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java index f5ba64f54..72409e3ca 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/collection/IntArrayKeyMap.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.collection; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import lombok.Getter; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java index 1f5c1f9cb..e17fec5ce 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/primitives/AtomicDouble.java @@ -24,7 +24,7 @@ import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; @JsonSerialize(using = JsonSerializerAtomicDouble.class) @JsonDeserialize(using = JsonDeserializerAtomicDouble.class) -public class AtomicDouble extends com.google.common.util.concurrent.AtomicDouble { +public class AtomicDouble extends org.nd4j.shade.guava.util.concurrent.AtomicDouble { public AtomicDouble(){ this(0.0); @@ -40,7 +40,7 @@ public class AtomicDouble extends com.google.common.util.concurrent.AtomicDouble @Override public boolean equals(Object o){ - //NOTE: com.google.common.util.concurrent.AtomicDouble extends Number, hence this class extends number + //NOTE: org.nd4j.shade.guava.util.concurrent.AtomicDouble extends Number, hence this class extends number if(o instanceof Number){ return get() == ((Number)o).doubleValue(); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index e51e75ce4..2fe33dfba 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -16,8 +16,8 @@ package org.nd4j.linalg.util; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java index e08605881..0c4377ea9 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/SynchronizedTable.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.util; -import com.google.common.collect.Table; +import org.nd4j.shade.guava.collect.Table; import java.util.Collection; import java.util.Map; diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java index 83799381b..f8fca14b5 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java @@ -1,6 +1,6 @@ package org.nd4j.resources.strumpf; -import com.google.common.io.Files; +import org.nd4j.shade.guava.io.Files; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; diff --git 
a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java index e7c8d977b..1b420d6a4 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java @@ -16,7 +16,7 @@ package org.nd4j.parameterserver.distributed.transport; -import com.google.common.math.IntMath; +import org.nd4j.shade.guava.math.IntMath; import io.aeron.Aeron; import io.aeron.FragmentAssembler; import io.aeron.Publication; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java index a34720093..70228b987 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java @@ -16,7 +16,7 @@ package org.nd4j.parameterserver.distributed.v2.transport.impl; -import com.google.common.math.IntMath; +import org.nd4j.shade.guava.math.IntMath; import io.aeron.Aeron; import io.aeron.FragmentAssembler; import io.aeron.Publication; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 0a75e6171..dd50f938e 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -47,16 +47,6 @@ nd4j-parameter-server ${project.version} - - com.typesafe.akka - akka-actor_2.11 - ${akka.version} - - - com.typesafe.akka - akka-slf4j_2.11 - ${akka.version} - joda-time joda-time diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java index e5152e327..e8b149f67 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/src/main/java/org/nd4j/parameterserver/status/play/StatusServer.java @@ -22,17 +22,14 @@ import org.nd4j.parameterserver.model.MasterStatus; import org.nd4j.parameterserver.model.ServerTypeJson; import org.nd4j.parameterserver.model.SlaveStatus; import org.nd4j.parameterserver.model.SubscriberState; +import play.BuiltInComponents; import play.Mode; -import play.libs.F; import play.libs.Json; -import play.mvc.Result; +import play.routing.Router; import play.routing.RoutingDsl; import play.server.Server; -import java.util.List; - import static play.libs.Json.toJson; -import static play.mvc.Controller.request; import static play.mvc.Results.ok; @@ -70,74 +67,35 @@ public class StatusServer { 
*/ public static Server startServer(StatusStorage statusStorage, int statusServerPort) { log.info("Starting server on port " + statusServerPort); - RoutingDsl dsl = new RoutingDsl(); - dsl.GET("/ids/").routeTo(new F.Function0() { - - @Override - public Result apply() throws Throwable { - List ids = statusStorage.ids(); - return ok(toJson(ids)); - } - }); - - - dsl.GET("/state/:id").routeTo(new F.Function() { - @Override - public Result apply(String id) throws Throwable { - return ok(toJson(statusStorage.getState(Integer.parseInt(id)))); - } - }); - - dsl.GET("/opType/:id").routeTo(new F.Function() { - @Override - public Result apply(String id) throws Throwable { - return ok(toJson(ServerTypeJson.builder() - .type(statusStorage.getState(Integer.parseInt(id)).serverType()))); - } - }); - - - dsl.GET("/started/:id").routeTo(new F.Function() { - @Override - public Result apply(String id) throws Throwable { - return statusStorage.getState(Integer.parseInt(id)).isMaster() - ? ok(toJson(MasterStatus.builder() - .master(statusStorage.getState(Integer.parseInt(id)).getServerState()) - //note here that a responder is id + 1 - .responder(statusStorage - .getState(Integer.parseInt(id) + 1).getServerState()) - .responderN(statusStorage - .getState(Integer.parseInt(id)).getTotalUpdates()) - .build())) - : ok(toJson(SlaveStatus.builder() - .slave(statusStorage.getState(Integer.parseInt(id)).serverType()) - .build())); - } - }); - - - - dsl.GET("/connectioninfo/:id").routeTo(new F.Function() { - @Override - public Result apply(String id) throws Throwable { - return ok(toJson(statusStorage.getState(Integer.parseInt(id)).getConnectionInfo())); - } - }); - - dsl.POST("/updatestatus/:id").routeTo(new F.Function() { - @Override - public Result apply(String id) throws Throwable { - SubscriberState subscriberState = Json.fromJson(request().body().asJson(), SubscriberState.class); - statusStorage.updateState(subscriberState); - return ok(toJson(subscriberState)); - } - }); - - Server server = Server.forRouter(dsl.build(), Mode.PROD, statusServerPort); - - return server; - + return Server.forRouter(Mode.PROD, statusServerPort, builtInComponents -> createRouter(statusStorage, builtInComponents)); } + protected static Router createRouter(StatusStorage statusStorage, BuiltInComponents builtInComponents){ + RoutingDsl dsl = RoutingDsl.fromComponents(builtInComponents); + dsl.GET("/ids/").routingTo(request -> ok(toJson(statusStorage.ids()))); + dsl.GET("/state/:id").routingTo((request, id) -> ok(toJson(statusStorage.getState(Integer.parseInt(id.toString()))))); + dsl.GET("/opType/:id").routingTo((request, id) -> ok(toJson(ServerTypeJson.builder() + .type(statusStorage.getState(Integer.parseInt(id.toString())).serverType())))); + dsl.GET("/started/:id").routingTo((request, id) -> { + boolean isMaster = statusStorage.getState(Integer.parseInt(id.toString())).isMaster(); + if(isMaster){ + return ok(toJson(MasterStatus.builder().master(statusStorage.getState(Integer.parseInt(id.toString())).getServerState()) + //note here that a responder is id + 1 + .responder(statusStorage.getState(Integer.parseInt(id.toString()) + 1).getServerState()) + .responderN(statusStorage.getState(Integer.parseInt(id.toString())).getTotalUpdates()) + .build())); + } else { + return ok(toJson(SlaveStatus.builder().slave(statusStorage.getState(Integer.parseInt(id.toString())).serverType()).build())); + } + }); + dsl.GET("/connectioninfo/:id").routingTo((request, id) -> 
ok(toJson(statusStorage.getState(Integer.parseInt(id.toString())).getConnectionInfo()))); + dsl.POST("/updatestatus/:id").routingTo((request, id) -> { + SubscriberState subscriberState = Json.fromJson(request.body().asJson(), SubscriberState.class); + statusStorage.updateState(subscriberState); + return ok(toJson(subscriberState)); + }); + + return dsl.build(); + } } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index f69b9c24b..e5ccca909 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -20,7 +20,7 @@ import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import com.beust.jcommander.Parameters; -import com.google.common.primitives.Ints; +import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.jackson.databind.ObjectMapper; import com.mashape.unirest.http.Unirest; diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml index 1f668eb66..aa60e9586 100644 --- a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml +++ b/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml @@ -44,13 +44,6 @@ test - - org.nd4j - nd4j-native - ${project.version} - test - - org.nd4j nd4j-api @@ -95,6 +88,30 @@ + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + testresources diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index cc19f4f6d..47ef995a9 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -123,46 +123,15 @@ activation 1.1 + + + com.google.code.gson + gson + ${gson.version} + test + - - - - nd4j-tests-cpu - - true - - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - - nd4j-tests-cuda - - false - - - - org.nd4j - nd4j-cuda-10.1 - ${project.version} - test - - - - - - testresources - - - - @@ -175,4 +144,34 @@ + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + + + testresources + + diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java index 426131d4d..8ea228586 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java @@ -16,7 +16,7 @@ package org.nd4j.aeron.ipc.chunk; -import com.google.common.collect.Maps; +import org.nd4j.shade.guava.collect.Maps; import lombok.extern.slf4j.Slf4j; import org.nd4j.aeron.ipc.NDArrayMessage; diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index 99f744587..4e4ba462e 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -49,36 +49,6 @@ joda-time ${jodatime.version} - - com.fasterxml.jackson.core - 
jackson-core - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${spark2.jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${spark2.jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-xml - ${spark2.jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${spark2.jackson.version} - org.apache.arrow arrow-vector diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index 1e8c0f509..82770d51a 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -21,9 +21,6 @@ org.nd4j 1.0.0-SNAPSHOT - - 2.8.0 - 4.0.0 nd4j-gson diff --git a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java index fc04af6d2..5b087a055 100644 --- a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java +++ b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java @@ -16,8 +16,8 @@ package org.nd4j.serde.gson; -import com.google.common.primitives.Ints; -import com.google.common.primitives.Longs; +import org.nd4j.shade.guava.primitives.Ints; +import org.nd4j.shade.guava.primitives.Longs; import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonParser; diff --git a/nd4j/nd4j-shade/guava/pom.xml b/nd4j/nd4j-shade/guava/pom.xml new file mode 100644 index 000000000..73b8d5825 --- /dev/null +++ b/nd4j/nd4j-shade/guava/pom.xml @@ -0,0 +1,219 @@ + + + + nd4j-shade + org.nd4j + 1.0.0-SNAPSHOT + + 4.0.0 + + guava + + + true + + + + + com.google.guava + guava + 28.0-jre + + true + + + + + + + custom-lifecycle + + + !skip.custom.lifecycle + + + + + + org.apache.portals.jetspeed-2 + jetspeed-mvn-maven-plugin + 2.3.1 + + + compile-and-pack + compile + + mvn + + + + + + org.apache.maven.shared + maven-invoker + 2.2 + + + + + + + create-shaded-jars + @rootdir@/nd4j/nd4j-shade/guava/ + clean,compile,package + + true + + + + + create-shaded-jars + + + + + + + + + + + + + com.lewisd + lint-maven-plugin + 0.0.11 + + + pom-lint + none + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + + package + + shade + + + + + reference.conf + + + + + + + + + + + false + true + true + + + + com.google.*:* + + + + + + com.google.common + org.nd4j.shade.guava + + + com.google + org.nd4j.shade + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + true + + + + empty-javadoc-jar + package + + jar + + + javadoc + ${basedir}/javadoc + + + + empty-sources-jar + package + + jar + + + sources + ${basedir}/src + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 3.0.0 + + + unpack + package + + unpack + + + + + org.nd4j + guava + ${project.version} + jar + false + ${project.build.directory}/classes/ + **/*.class,**/*.xml + + + + + + + + + + \ No newline at end of file diff --git a/nd4j/nd4j-shade/jackson/pom.xml b/nd4j/nd4j-shade/jackson/pom.xml index 1d53e1c41..ad2be71fb 100644 --- a/nd4j/nd4j-shade/jackson/pom.xml +++ b/nd4j/nd4j-shade/jackson/pom.xml @@ -32,6 +32,79 @@ true + + + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + true + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.databind.version} + true + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + 
${jackson.version} + true + + + com.fasterxml.jackson.dataformat + jackson-dataformat-xml + ${jackson.version} + true + + + + + jackson-module-jaxb-annotations + com.fasterxml.jackson.module + + + + + com.fasterxml.jackson.datatype + jackson-datatype-joda + ${jackson.version} + true + + + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + true + + + org.yaml + snakeyaml + ${shaded.snakeyaml.version} + true + + + org.codehaus.woodstox + stax2-api + 3.1.4 + true + + + com.fasterxml.woodstox + woodstox-core + 5.1.0 + true + + + + + custom-lifecycle @@ -139,6 +212,9 @@ com.fasterxml.jackson:* com.fasterxml.jackson.*:* + com.fasterxml.woodstox:* + org.yaml*:* + org.codehaus*:* @@ -148,6 +224,20 @@ com.fasterxml.jackson org.nd4j.shade.jackson + + com.ctc.wstx + org.nd4j.shade.wstx + + + + org.yaml + org.nd4j.shade.yaml + + + + org.codehaus + org.nd4j.shade.codehaus + @@ -214,45 +304,4 @@ - - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-yaml - ${jackson.version} - - - com.fasterxml.jackson.dataformat - jackson-dataformat-xml - ${jackson.version} - - - - - jackson-module-jaxb-annotations - com.fasterxml.jackson.module - - - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${jackson.version} - - - - - - diff --git a/nd4j/nd4j-shade/pom.xml b/nd4j/nd4j-shade/pom.xml index 36b58087b..927292b5f 100644 --- a/nd4j/nd4j-shade/pom.xml +++ b/nd4j/nd4j-shade/pom.xml @@ -30,6 +30,7 @@ jackson protobuf + guava diff --git a/nd4j/nd4j-shade/protobuf/pom.xml b/nd4j/nd4j-shade/protobuf/pom.xml index 1cbd7d5a8..910003683 100644 --- a/nd4j/nd4j-shade/protobuf/pom.xml +++ b/nd4j/nd4j-shade/protobuf/pom.xml @@ -20,11 +20,26 @@ com.google.protobuf protobuf-java 3.8.0 + + true com.google.protobuf protobuf-java-util 3.8.0 + true + + + com.google.guava + guava + + + + + com.google.guava + guava + 26.0-android + true
@@ -150,6 +165,7 @@ com.google.protobuf:* com.google.protobuf.*:* + com.google.guava:* @@ -159,6 +175,11 @@ com.google.protobuf org.nd4j.shade.protobuf + + + com.google.common + org.nd4j.shade.protobuf.common + diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 6c294d7e7..f043d7299 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -77,11 +77,6 @@ slf4j-api ${slf4j.version} - - com.google.guava - guava - ${guava.version} - junit junit diff --git a/nd4s/build.sbt b/nd4s/build.sbt index 701483d0f..1fbac5ae6 100644 --- a/nd4s/build.sbt +++ b/nd4s/build.sbt @@ -38,7 +38,7 @@ lazy val commonSettings = Seq( resolvers in ThisBuild ++= Seq(Opts.resolver.sonatypeSnapshots), nd4jVersion := sys.props.getOrElse("nd4jVersion", default = "1.0.0-SNAPSHOT"), libraryDependencies ++= Seq( - "com.nativelibs4java" %% "scalaxy-loops" % "0.3.4", +// "com.nativelibs4java" %% "scalaxy-loops" % "0.3.4", // "org.nd4j" % "nd4j-api" % nd4jVersion.value, // "org.nd4j" % "nd4j-native-platform" % nd4jVersion.value % Test, "org.scalatest" %% "scalatest" % "2.2.6" % Test, diff --git a/nd4s/pom.xml b/nd4s/pom.xml index df20a4423..f165cfae9 100644 --- a/nd4s/pom.xml +++ b/nd4s/pom.xml @@ -68,28 +68,23 @@ - - com.nativelibs4java - scalaxy-loops_${scala.binary.version} - 0.3.4 - org.nd4j nd4j-api ${nd4j.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - ch.qos.logback logback-classic ${logback.version} test + + junit + junit + ${junit.version} + test + org.scalatest scalatest_${scala.binary.version} @@ -105,7 +100,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + ${breeze.version} test @@ -193,6 +188,7 @@ -deprecation -explaintypes -nobootcp + -usejavacp @@ -295,4 +291,42 @@ + + + + test-nd4j-native + + + org.nd4j + nd4j-native + ${project.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${dl4j-test-resources.version} + test + + + + + + test-nd4j-cuda-10.1 + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${dl4j-test-resources.version} + test + + + + diff --git a/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala b/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala index 46bba7fff..bdbe43bc3 100644 --- a/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala +++ b/nd4s/src/main/scala/org/nd4s/CollectionLikeNDArray.scala @@ -21,7 +21,6 @@ import org.nd4j.linalg.api.ops.Op import org.nd4j.linalg.factory.Nd4j import org.nd4s.ops.{ BitFilterOps, FilterOps, FunctionalOpExecutioner, MapOps } -import scalaxy.loops._ import scala.language.postfixOps import scala.util.control.Breaks._ @@ -65,7 +64,7 @@ trait CollectionLikeNDArray[A <: INDArray] { val lv = ev.linearView(underlying) breakable { for { - i <- 0 until lv.length().toInt optimized + i <- 0 until lv.length().toInt } if (!f(ev.get(lv, i))) { result = true break() @@ -81,7 +80,7 @@ trait CollectionLikeNDArray[A <: INDArray] { val lv = ev.linearView(underlying) breakable { for { - i <- 0 until lv.length().toInt optimized + i <- 0 until lv.length().toInt } if (!f(ev.get(lv, i))) { result = false break() diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index 8ca21b72e..e43a1e86e 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -90,8 +90,13 @@ class SDVariableWrapper { def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other) def &(other: SDVariable)(implicit sameDiff: SameDiff): 
SDVariable = sameDiff.math.and(thisVariable, other) - def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x) - def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x) + def <<(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, other) + def >>(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShiftRight(null, thisVariable, other) + def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShift(null, thisVariable, sameDiff.constant(x)) + def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShiftRight(null, thisVariable, sameDiff.constant(x)) // Overloads for numeric arguments // Float diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index a2c113b50..8e9f892c6 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -188,4 +188,16 @@ class MathTest extends FlatSpec with Matchers { val w3 = w1 >> 2 w3.eval.toIntVector.head shouldBe 4 } + + "SameDiff" should "provide shifting operations with SDVariable argument" in { + implicit val sd = SameDiff.create() + val w1 = sd.constant(16) + val two = sd.constant(2) + + val w2 = w1 << two + w2.eval.toIntVector.head shouldBe 64 + + val w3 = w1 >> two + w3.eval.toIntVector.head shouldBe 4 + } } diff --git a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala index c9ee41bfb..d0efee304 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala @@ -4,7 +4,7 @@ import java.lang.reflect.Field import java.util import java.util.{ Arrays, Collections, HashMap, List, Map } -import com.google.common.collect.{ Lists, Maps } +import org.nd4j.shade.guava.collect.{ Lists, Maps } import org.junit.Assert._ import org.junit.Assume.assumeNotNull import org.nd4j.autodiff.samediff._ @@ -15,7 +15,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose import org.nd4j.linalg.api.buffer.DataType import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ops.DynamicCustomOp -import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear } +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig } import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray diff --git a/pom.xml b/pom.xml index ee24279d4..63054c104 100644 --- a/pom.xml +++ b/pom.xml @@ -261,15 +261,13 @@ 3.4.2 0.8.2.2 - 2.3.16 - 1.3.0 + 1.3.0 0.10.4 1.27 0.8.0 2.2 1.15 3.17 - 2.4.8 0.5.0 2.3.23 2.8.1 @@ -280,6 +278,7 @@ 6.5.7 1.4.9 0.9.10 + 1.0 false false @@ -308,7 +307,7 @@ 1.14.0 ${tensorflow.version}-${javacpp-presets.version} - 1.16.1 + 1.18 3.5 3.6 2.5 @@ -325,13 +324,15 @@ 3.2.2 4.1 + 2.4.3 + 2 2.0.29 1.7.21 4.12 1.2.3 - 2.5.1 - ${jackson.version} - 2.6.5 + 2.9.9 + 2.9.9.3 + 1.23 2.8.7 1.18.2 2.0.0 @@ -339,18 +340,16 @@ 20131018 2.6.1 false - 2.2.0 - - 2.1.0 + 2.2.0 2.16.3 3.4.6 0.5.4 3.0.5 3.15.1 - 2.4.8 - + 2.7.3 2.0 - 20.0 + 28.0-jre + 2.8.0 1.2.0-3f79e055 4.10.0 @@ -386,12 +385,12 @@ 2.2.6 - - 2.10.7 - 2.10 2.11.12 2.11 + + 2.12.9 + 2.12 3.0.5 1.3.0 diff --git a/pydl4j/pom.xml 
b/pydl4j/pom.xml index d64c9bbdb..279c57e07 100644 --- a/pydl4j/pom.xml +++ b/pydl4j/pom.xml @@ -44,12 +44,13 @@ false 0.1.3 + nd4j-native org.nd4j - nd4j-native + ${nd4j.backend} ${dl4j.version} @@ -179,4 +180,18 @@ + + + + nd4j-backend + + + libnd4j.cuda + + + + nd4j-cuda-${libnd4j.cuda} + + + diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index 1831700b2..67b050ba1 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -91,12 +91,6 @@ deeplearning4j-core ${dl4j.version} - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.apache.commons commons-collections4 @@ -107,5 +101,37 @@ jackson-databind ${jackson.version} + + + com.google.code.gson + gson + ${gson.version} + + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + + diff --git a/scalnet/pom.xml b/scalnet/pom.xml index 0f83a9b69..a6e220280 100644 --- a/scalnet/pom.xml +++ b/scalnet/pom.xml @@ -61,12 +61,6 @@ - - org.nd4j - nd4j-native - ${nd4j.version} - test - org.scala-lang scala-library @@ -94,6 +88,12 @@ ${scalacheck.version} test + + + com.google.code.gson + gson + ${gson.version} + @@ -274,4 +274,30 @@ + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.1 + ${project.version} + test + + + +