Development updates (#9053)
* RL4J: Add generic update rule (#502) Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Shyrma reduce (#481) * - start working on improving of cpu legacy code for reduce ops Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on improving legacy loops Signed-off-by: Yurii <iuriish@yahoo.com> * - still working on improving reduce ops Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on improving reduce ops Signed-off-by: Yurii <iuriish@yahoo.com> * - testing speed run of new reduce op Signed-off-by: Yurii <iuriish@yahoo.com> * - working on improvement of default loop for reduce op Signed-off-by: Yurii <iuriish@yahoo.com> * - update signatures of stuff which calls reduce ops Signed-off-by: Yurii <iuriish@yahoo.com> * - make corrections in cuda reduce kernels Signed-off-by: Yurii <iuriish@yahoo.com> * - change loop for default case in broadcast legacy ops Signed-off-by: Yurii <iuriish@yahoo.com> * - comment some shape stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - comment unnecessary prints in RNGtests Signed-off-by: Yurii <iuriish@yahoo.com> * - finish to resolve conflicts after master has been merged Signed-off-by: Yurii <iuriish@yahoo.com> * - get rid of some compilation mistakes of cuda stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - minor changes Signed-off-by: Yurii <iuriish@yahoo.com> * - further search for bug causing crash on java test Signed-off-by: Yurii <iuriish@yahoo.com> * - add scalar case in reduce_ ... exec stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - minor corrections in NAtiveOps.cu Signed-off-by: Yurii <iuriish@yahoo.com> * - add switch to scalar case execReduceXD functions Signed-off-by: Yurii <iuriish@yahoo.com> * - add support for vectors old shape in ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce Signed-off-by: Yurii <iuriish@yahoo.com> * - correct cuda mirrorPad Signed-off-by: Yurii <iuriish@yahoo.com> * - add support for vectors old shape in cuda createShapeInfoWithNoUnitiesForReduce Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com> * Add support for CUDA 11.0 (#492) * Add support for CUDA 11.0 * libnd4j tweaks for CUDA 11 Signed-off-by: raver119@gmail.com <raver119@gmail.com> * bindings update, again? Signed-off-by: raver119@gmail.com <raver119@gmail.com> * * Update versions of JavaCPP Presets for FFmpeg, OpenBLAS, and NumPy * update API to match CUDA 8 Signed-off-by: raver119@gmail.com <raver119@gmail.com> * * Update version of JavaCPP Presets for CPython * C++ updated for cuDNN 8.0 Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more test Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more test Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more test Signed-off-by: raver119@gmail.com <raver119@gmail.com> * 128-bit alignment for workspaces Signed-off-by: raver119@gmail.com <raver119@gmail.com> * change seed in 1 test Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fix dependecy duplication in python4j-parent pom * Fix group id for in python4j-numpy * few tests tweaked Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Remove macosx-x86_64-gpu from nd4j-tests-tensorflow * few minor tweaks for IndexReduce Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one test removed Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> * RL4J: Add SyncTrainer and AgentLearnerBuilder for a few algorithms (#504) Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> Co-authored-by: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com> Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com>master
parent
ef665bfe49
commit
029b84e2b7
|
@ -99,7 +99,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -77,7 +77,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -63,7 +63,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -37,7 +37,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -151,7 +151,7 @@
|
||||||
<skip>${skipTestResourceEnforcement}</skip>
|
<skip>${skipTestResourceEnforcement}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
|
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -333,11 +333,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
VALID_VERSIONS=( 9.2 10.0 10.1 10.2 )
|
VALID_VERSIONS=( 9.2 10.0 10.1 10.2 11.0 )
|
||||||
|
|
||||||
usage() {
|
usage() {
|
||||||
echo "Usage: $(basename $0) [-h|--help] <cuda version to be used>
|
echo "Usage: $(basename $0) [-h|--help] <cuda version to be used>
|
||||||
|
@ -47,6 +47,10 @@ check_cuda_version() {
|
||||||
check_cuda_version "$VERSION"
|
check_cuda_version "$VERSION"
|
||||||
|
|
||||||
case $VERSION in
|
case $VERSION in
|
||||||
|
11.0)
|
||||||
|
VERSION2="8.0"
|
||||||
|
VERSION3="1.5.4-SNAPSHOT"
|
||||||
|
;;
|
||||||
10.2)
|
10.2)
|
||||||
VERSION2="7.6"
|
VERSION2="7.6"
|
||||||
VERSION3="1.5.3"
|
VERSION3="1.5.3"
|
||||||
|
|
|
@ -126,7 +126,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -62,7 +62,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -79,7 +79,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -66,7 +66,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -128,7 +128,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -81,7 +81,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -56,7 +56,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -74,7 +74,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -54,7 +54,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -72,7 +72,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -95,7 +95,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -78,7 +78,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -60,7 +60,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -59,7 +59,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -178,7 +178,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -38,7 +38,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -144,7 +144,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -108,7 +108,7 @@
|
||||||
<skip>${skipTestResourceEnforcement}</skip>
|
<skip>${skipTestResourceEnforcement}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
|
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -361,11 +361,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -62,11 +62,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -40,7 +40,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -180,11 +180,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
<artifactId>deeplearning4j-cuda-10.2</artifactId>
|
<artifactId>deeplearning4j-cuda-11.0</artifactId>
|
||||||
<name>deeplearning4j-cuda</name>
|
<name>deeplearning4j-cuda</name>
|
||||||
<parent>
|
<parent>
|
||||||
<groupId>org.deeplearning4j</groupId>
|
<groupId>org.deeplearning4j</groupId>
|
||||||
|
@ -26,9 +26,9 @@
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
|
<!-- CUDA version is linked with the artifact name so cannot move to parent pom.xml -->
|
||||||
<cuda.version>10.2</cuda.version>
|
<cuda.version>11.0</cuda.version>
|
||||||
<cudnn.version>7.6</cudnn.version>
|
<cudnn.version>8.0</cudnn.version>
|
||||||
<javacpp-presets.cuda.version>1.5.3</javacpp-presets.cuda.version>
|
<javacpp-presets.cuda.version>1.5.4-SNAPSHOT</javacpp-presets.cuda.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencyManagement>
|
<dependencyManagement>
|
||||||
|
@ -112,7 +112,7 @@
|
||||||
</build>
|
</build>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -254,17 +254,34 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
|
throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
/*
|
||||||
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
|
||||||
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
|
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
|
||||||
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
|
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
|
||||||
0, algo1);
|
0, algo1);
|
||||||
|
*/
|
||||||
|
val fa = new cudnnConvolutionBwdFilterAlgoPerf_t();
|
||||||
|
val counts = new int[1];
|
||||||
|
code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa);
|
||||||
|
algo1[0] = fa.algo();
|
||||||
|
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
|
|
||||||
|
/*
|
||||||
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
|
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
|
||||||
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
|
||||||
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
|
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
|
||||||
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
|
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
|
||||||
0, algo2);
|
0, algo2);
|
||||||
|
*/
|
||||||
|
|
||||||
|
val da = new cudnnConvolutionBwdDataAlgoPerf_t();
|
||||||
|
code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
|
||||||
|
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da);
|
||||||
|
|
||||||
|
algo2[0] = da.algo();
|
||||||
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -461,11 +478,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
|
throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
code = cudnnGetConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
|
/*
|
||||||
|
code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc,
|
||||||
cudnnContext.filterDesc, cudnnContext.convDesc,
|
cudnnContext.filterDesc, cudnnContext.convDesc,
|
||||||
cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
|
cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
|
||||||
? CUDNN_CONVOLUTION_FWD_NO_WORKSPACE : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
|
? CUDNN_CONVOLUTION_FWD_ : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
|
||||||
0, algo);
|
0, algo);
|
||||||
|
*/
|
||||||
|
|
||||||
|
val cdf = new cudnnConvolutionFwdAlgoPerf_t();
|
||||||
|
val count = new int[1];
|
||||||
|
code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf);
|
||||||
|
|
||||||
if(code != CUDNN_STATUS_SUCCESS){
|
if(code != CUDNN_STATUS_SUCCESS){
|
||||||
//If CuDNN can't infer algorithm - try IMPLICIT_GEMM
|
//If CuDNN can't infer algorithm - try IMPLICIT_GEMM
|
||||||
|
@ -477,6 +500,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
|
||||||
fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
|
fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
|
||||||
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
|
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
algo[0] = cdf.algo();
|
||||||
}
|
}
|
||||||
|
|
||||||
if(log.isTraceEnabled()){
|
if(log.isTraceEnabled()){
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -269,7 +270,7 @@ public class ValidateCudnnLSTM extends BaseDL4JTest {
|
||||||
assertTrue(f.get(l0) instanceof CudnnLSTMHelper);
|
assertTrue(f.get(l0) instanceof CudnnLSTMHelper);
|
||||||
assertTrue(f.get(l1) instanceof CudnnLSTMHelper);
|
assertTrue(f.get(l1) instanceof CudnnLSTMHelper);
|
||||||
|
|
||||||
Random r = new Random(12345);
|
Random r = new Random(123456);
|
||||||
for (int x = 0; x < 1; x++) {
|
for (int x = 0; x < 1; x++) {
|
||||||
INDArray input = Nd4j.rand(new int[] {minibatch, inputSize, timeSeriesLength});
|
INDArray input = Nd4j.rand(new int[] {minibatch, inputSize, timeSeriesLength});
|
||||||
INDArray labels = Nd4j.zeros(minibatch, nOut, timeSeriesLength);
|
INDArray labels = Nd4j.zeros(minibatch, nOut, timeSeriesLength);
|
||||||
|
@ -284,7 +285,6 @@ public class ValidateCudnnLSTM extends BaseDL4JTest {
|
||||||
mln2.fit(ds);
|
mln2.fit(ds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
assertEquals(mln1.params(), mln2.params());
|
assertEquals(mln1.params(), mln2.params());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -46,7 +46,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -43,7 +43,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -38,7 +38,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -116,11 +116,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -64,7 +64,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -41,7 +41,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -302,11 +302,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -135,11 +135,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -125,11 +125,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -49,7 +49,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -53,7 +53,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -89,11 +89,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -44,7 +44,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -72,7 +72,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -110,7 +110,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -128,7 +128,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -107,14 +107,14 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -109,11 +109,11 @@
|
||||||
</profile>
|
</profile>
|
||||||
|
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -104,7 +104,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -181,7 +181,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -77,7 +77,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -457,7 +457,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -49,7 +49,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,7 @@
|
||||||
<id>test-nd4j-native</id>
|
<id>test-nd4j-native</id>
|
||||||
</profile>
|
</profile>
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
</profile>
|
</profile>
|
||||||
</profiles>
|
</profiles>
|
||||||
</project>
|
</project>
|
|
@ -225,7 +225,7 @@
|
||||||
<skip>${skipBackendChoice}</skip>
|
<skip>${skipBackendChoice}</skip>
|
||||||
<rules>
|
<rules>
|
||||||
<requireActiveProfile>
|
<requireActiveProfile>
|
||||||
<profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
|
<profiles>test-nd4j-native,test-nd4j-cuda-11.0</profiles>
|
||||||
<all>false</all>
|
<all>false</all>
|
||||||
</requireActiveProfile>
|
</requireActiveProfile>
|
||||||
</rules>
|
</rules>
|
||||||
|
@ -500,7 +500,7 @@
|
||||||
</profile>
|
</profile>
|
||||||
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
<!-- For running unit tests with nd4j-cuda-8.0: "mvn clean test -P test-nd4j-cuda-8.0" -->
|
||||||
<profile>
|
<profile>
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
<id>test-nd4j-cuda-11.0</id>
|
||||||
<activation>
|
<activation>
|
||||||
<activeByDefault>false</activeByDefault>
|
<activeByDefault>false</activeByDefault>
|
||||||
</activation>
|
</activation>
|
||||||
|
@ -513,7 +513,7 @@
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-cuda-10.2</artifactId>
|
<artifactId>nd4j-cuda-11.0</artifactId>
|
||||||
<version>${nd4j.version}</version>
|
<version>${nd4j.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
|
@ -628,17 +628,17 @@ namespace sd {
|
||||||
* keepDims - if true then put unities in place of reduced dimensions
|
* keepDims - if true then put unities in place of reduced dimensions
|
||||||
*/
|
*/
|
||||||
|
|
||||||
NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
|
||||||
NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
|
||||||
|
|
||||||
NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
|
||||||
NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
|
||||||
|
|
||||||
NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
|
||||||
NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
|
||||||
|
|
||||||
NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
|
||||||
NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
|
NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* method reduces array by excluding its shapes along dimensions present in given dimensions vector
|
* method reduces array by excluding its shapes along dimensions present in given dimensions vector
|
||||||
|
@ -647,10 +647,10 @@ namespace sd {
|
||||||
* keepDims - if true then put unities in place of reduced dimensions
|
* keepDims - if true then put unities in place of reduced dimensions
|
||||||
* extras - extra parameters
|
* extras - extra parameters
|
||||||
*/
|
*/
|
||||||
void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
|
void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
|
||||||
void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
|
void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
|
||||||
void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
|
void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
|
||||||
void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
|
void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* return variance of array elements set
|
* return variance of array elements set
|
||||||
|
|
|
@ -1353,80 +1353,80 @@ void* NDArray::bufferWithOffset(Nd4jLong offset) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// eventually method reduces array by excluding its shapes along axes present in dimensions vector
|
// eventually method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims) const {
|
||||||
|
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, false, getContext()->getWorkspace());
|
||||||
|
|
||||||
NDArray result(newShape, true, getContext());
|
NDArray result(newShape, true, getContext());
|
||||||
|
|
||||||
this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
|
this->reduceAlongDimension(op, result, copy, keepDims, false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims) const {
|
||||||
|
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, false, getContext()->getWorkspace());
|
||||||
|
|
||||||
NDArray result(newShape, true, getContext());
|
NDArray result(newShape, true, getContext());
|
||||||
|
|
||||||
reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
|
reduceAlongDimension(op, result, copy, keepDims, false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims) const {
|
||||||
|
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, false, getContext()->getWorkspace());
|
||||||
|
|
||||||
NDArray result(newShape, true, getContext());
|
NDArray result(newShape, true, getContext());
|
||||||
|
|
||||||
reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
|
reduceAlongDimension(op, result, copy, keepDims, false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims) const {
|
||||||
|
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, false, getContext()->getWorkspace());
|
||||||
|
|
||||||
NDArray result(newShape, true, getContext());
|
NDArray result(newShape, true, getContext());
|
||||||
|
|
||||||
reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
|
reduceAlongDimension(op, result, copy, keepDims, false);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method reduces array by excluding its shapes along axes present in dimensions vector
|
// method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
|
||||||
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
|
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
|
||||||
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
|
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
|
||||||
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
|
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
|
NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
|
||||||
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
|
return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -4240,7 +4240,7 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, cons
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method reduces array by excluding its shapes along axes present in dimensions vector
|
// method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
|
void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!");
|
||||||
|
@ -4250,7 +4250,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
if(checkTargetShape) {
|
if(checkTargetShape) {
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
|
||||||
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!");
|
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!");
|
||||||
}
|
}
|
||||||
|
@ -4261,8 +4261,18 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
|
||||||
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
|
const Nd4jLong* zShapeInfoH = target.shapeInfo();
|
||||||
NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
|
||||||
|
|
||||||
|
if(rankOf() - dimensions.size() != target.rankOf()) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
|
||||||
|
NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
|
||||||
|
|
||||||
}
|
}
|
||||||
synchronize("NDArray::reduceAlongDimension FloatOps");
|
synchronize("NDArray::reduceAlongDimension FloatOps");
|
||||||
|
|
||||||
|
@ -4271,7 +4281,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method reduces array by excluding its shapes along axes present in dimensions vector
|
// method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
|
void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!");
|
||||||
|
@ -4281,7 +4291,7 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
if(checkTargetShape) {
|
if(checkTargetShape) {
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
|
||||||
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!");
|
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!");
|
||||||
}
|
}
|
||||||
|
@ -4291,10 +4301,19 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
|
||||||
if(rankOf() == copy.size() || copy.empty()) {
|
if(rankOf() == copy.size() || copy.empty()) {
|
||||||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||||
}
|
}
|
||||||
else { //if (!isEmpty()) {
|
else {
|
||||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
|
||||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
const Nd4jLong* zShapeInfoH = target.shapeInfo();
|
||||||
NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
|
||||||
|
|
||||||
|
if(rankOf() - dimensions.size() != target.rankOf()) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
|
||||||
|
NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
|
||||||
}
|
}
|
||||||
synchronize("NDArray::reduceAlongDimension SameOps");
|
synchronize("NDArray::reduceAlongDimension SameOps");
|
||||||
|
|
||||||
|
@ -4303,7 +4322,7 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method reduces array by excluding its shapes along axes present in dimensions vector
|
// method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
|
void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!");
|
||||||
|
@ -4313,7 +4332,7 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
if(checkTargetShape) {
|
if(checkTargetShape) {
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
|
||||||
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!");
|
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!");
|
||||||
}
|
}
|
||||||
|
@ -4324,9 +4343,17 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
|
||||||
NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
const Nd4jLong* zShapeInfoH = target.shapeInfo();
|
||||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
|
||||||
NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
|
||||||
|
if(rankOf() - dimensions.size() != target.rankOf()) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
|
||||||
|
NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
|
||||||
}
|
}
|
||||||
synchronize("NDArray::reduceAlongDimension LongOps");
|
synchronize("NDArray::reduceAlongDimension LongOps");
|
||||||
|
|
||||||
|
@ -4335,7 +4362,7 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method reduces array by excluding its shapes along axes present in dimensions vector
|
// method reduces array by excluding its shapes along axes present in dimensions vector
|
||||||
void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
|
void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!");
|
||||||
|
@ -4345,7 +4372,7 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, cons
|
||||||
std::vector<int> copy(dimensions);
|
std::vector<int> copy(dimensions);
|
||||||
|
|
||||||
if(checkTargetShape) {
|
if(checkTargetShape) {
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
|
auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
|
||||||
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
if(!shape::shapeEquals(newShape, target.shapeInfo()))
|
||||||
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!");
|
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!");
|
||||||
}
|
}
|
||||||
|
@ -4356,9 +4383,17 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, cons
|
||||||
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
|
const Nd4jLong* zShapeInfoH = target.shapeInfo();
|
||||||
auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
|
const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
|
||||||
NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
|
||||||
|
if(rankOf() - dimensions.size() != target.rankOf()) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
|
||||||
|
NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
|
||||||
}
|
}
|
||||||
synchronize("NDArray::reduceAlongDimension LongOps");
|
synchronize("NDArray::reduceAlongDimension LongOps");
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,8 @@ namespace sd {
|
||||||
ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
|
||||||
ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
|
ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
|
||||||
ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
|
ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
|
||||||
|
ConstantShapeBuffer& createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* maxShapeInfo, const std::vector<int> &dimsWithUnities, sd::memory::Workspace* workspace = nullptr);
|
||||||
|
ConstantShapeBuffer& createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
|
||||||
const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
|
const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
|
||||||
|
|
|
@ -41,43 +41,43 @@ namespace sd {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop);
|
static FORCEINLINE void loopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, E* extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
class ReductionFloatLoops : public ReductionLoops<X, Z, Z> {
|
class ReductionFloatLoops : public ReductionLoops<X, Z, Z> {
|
||||||
public:
|
public:
|
||||||
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
|
static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, Z* extraParams);
|
||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
|
static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, Z* extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> {
|
class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> {
|
||||||
public:
|
public:
|
||||||
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> {
|
class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> {
|
||||||
public:
|
public:
|
||||||
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> {
|
class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> {
|
||||||
public:
|
public:
|
||||||
static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, X* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
|
static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, X* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,371 +122,612 @@ namespace sd {
|
||||||
static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
|
static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
static void reduceExec21(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x0);
|
||||||
|
|
||||||
|
if(xStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1 * xStrd1], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z0 = OpType::postProcess(s, static_cast<Nd4jLong>(xAxis1), extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,xAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
static void reduceExec31(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis1 * xAxis2);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x0);
|
||||||
|
|
||||||
|
if(xStrd1 == 1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2], extraParams), extraParams);
|
||||||
|
else if(xStrd2 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z0 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,xAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
void reduceExec32(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
/*
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
template<typename X, typename Y, typename Z>
|
|
||||||
void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
|
|
||||||
const Y* y, const Nd4jLong* yShapeInfo,
|
|
||||||
Z* z, const Nd4jLong* zShapeInfo,
|
|
||||||
Z* extraParams,
|
|
||||||
std::function<Z(X,Y,Z*)> op) {
|
|
||||||
|
|
||||||
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo);
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
const Nd4jLong* xShape = shape::shapeOf(xShapeInfo);
|
auto s = OpType::startingValue(x1);
|
||||||
const Nd4jLong* xStride = shape::stride(xShapeInfo);
|
|
||||||
const Nd4jLong* yStride = shape::stride(yShapeInfo);
|
|
||||||
const Nd4jLong* zStride = shape::stride(zShapeInfo);
|
|
||||||
|
|
||||||
const Nd4jLong len = shape::length(xShapeInfo);
|
if(xStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2 * xStrd2], extraParams), extraParams);
|
||||||
|
|
||||||
OmpLaunchHelper threadsInfo(len);
|
*z1 = OpType::postProcess(s, static_cast<Nd4jLong>(xAxis2), extraParams);
|
||||||
|
|
||||||
switch (kindOfLoop) {
|
|
||||||
|
|
||||||
case LoopKind::EWS1: {
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
|
|
||||||
{
|
|
||||||
const auto threadNum = omp_get_thread_num();
|
|
||||||
const auto threadOffset = threadsInfo.getThreadOffset(threadNum);
|
|
||||||
const auto lenPerThread = static_cast<uint>(threadsInfo.getItersPerThread(threadNum));
|
|
||||||
|
|
||||||
const auto xi = x + threadOffset;
|
|
||||||
const auto yi = y + threadOffset;
|
|
||||||
auto zi = z + threadOffset;
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (uint i = 0; i < lenPerThread; i++)
|
|
||||||
zi[i] = op(xi[i], yi[i], extraParams);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
};
|
||||||
|
|
||||||
case LoopKind::EWSNONZERO: {
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
|
||||||
const uint xEws = shape::elementWiseStride(xShapeInfo);
|
}
|
||||||
const uint yEws = shape::elementWiseStride(yShapeInfo);
|
|
||||||
const uint zEws = shape::elementWiseStride(zShapeInfo);
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
|
//////////////////////////////////////////////////////////////////////////
|
||||||
{
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
const auto threadNum = omp_get_thread_num();
|
void reduceExec41(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
const auto threadOffset = threadsInfo.getThreadOffset(threadNum);
|
|
||||||
const auto lenPerThread = static_cast<uint>(threadsInfo.getItersPerThread(threadNum));
|
|
||||||
const auto xi = x + threadOffset * xEws;
|
|
||||||
const auto yi = y + threadOffset * yEws;
|
|
||||||
auto zi = z + threadOffset * zEws;
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
|
||||||
for (uint i = 0; i < lenPerThread; i++)
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
|
||||||
zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams);
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis1 * xAxis2 * xAxis3);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x0);
|
||||||
|
|
||||||
|
if(xStrd1 == 1)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
|
||||||
|
else if(xStrd2 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2 + i3*xStrd3], extraParams), extraParams);
|
||||||
|
else if(xStrd3 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z0 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,xAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
void reduceExec42(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis2 * xAxis3);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
|
||||||
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x1);
|
||||||
|
|
||||||
|
if(xStrd2 == 1)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2 + i3*xStrd3], extraParams), extraParams);
|
||||||
|
else if(xStrd3 == 1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z1 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
};
|
||||||
|
|
||||||
case LoopKind::RANK1: {
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
}
|
||||||
for (uint i0 = 0; i0 < len; ++i0)
|
|
||||||
z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], extraParams);
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
void reduceExec43(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
|
||||||
|
const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
|
||||||
|
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
||||||
|
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x2);
|
||||||
|
|
||||||
|
if(xStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x2[i3], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x2[i3*xStrd3], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z2 = OpType::postProcess(s, static_cast<Nd4jLong>(xAxis3), extraParams);
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
|
|
||||||
case LoopKind::RANK2: {
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (uint i0 = 0; i0 < xShape[0]; ++i0)
|
|
||||||
for (uint i1 = 0; i1 < xShape[1]; ++i1)
|
|
||||||
z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] + i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams);
|
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
|
|
||||||
case LoopKind::RANK3: {
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
|
|
||||||
for (uint i0 = 0; i0 < xShape[0]; ++i0)
|
|
||||||
for (uint i1 = 0; i1 < xShape[1]; ++i1)
|
|
||||||
for (uint i2 = 0; i2 < xShape[2]; ++i2)
|
|
||||||
z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams);
|
|
||||||
}
|
}
|
||||||
break;
|
};
|
||||||
|
|
||||||
case LoopKind::RANK4: {
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3)
|
}
|
||||||
for (uint i0 = 0; i0 < xShape[0]; ++i0)
|
|
||||||
for (uint i1 = 0; i1 < xShape[1]; ++i1)
|
//////////////////////////////////////////////////////////////////////////
|
||||||
for (uint i2 = 0; i2 < xShape[2]; ++i2)
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
for (uint i3 = 0; i3 < xShape[3]; ++i3)
|
void reduceExec51(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams);
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
|
||||||
|
const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis1 * xAxis2 * xAxis3 * xAxis4);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x0);
|
||||||
|
|
||||||
|
if(xStrd1 == 1)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd2 == 1)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd3 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd4 == 1)
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3 + i4], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < xAxis1; ++i1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z0 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
}
|
}
|
||||||
break;
|
};
|
||||||
|
|
||||||
case LoopKind::RANK5: {
|
samediff::Threads::parallel_for(func, 0,xAxis0);
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4)
|
}
|
||||||
for (uint i0 = 0; i0 < xShape[0]; ++i0)
|
|
||||||
for (uint i1 = 0; i1 < xShape[1]; ++i1)
|
//////////////////////////////////////////////////////////////////////////
|
||||||
for (uint i2 = 0; i2 < xShape[2]; ++i2)
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
for (uint i3 = 0; i3 < xShape[3]; ++i3)
|
void reduceExec52(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
for (uint i4 = 0; i4 < xShape[4]; ++i4)
|
|
||||||
z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], extraParams);
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
|
||||||
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
|
||||||
|
const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis2 * xAxis3 * xAxis4);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
|
||||||
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x1);
|
||||||
|
|
||||||
|
if(xStrd2 == 1)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd3 == 1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd4 == 1)
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3 + i4], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < xAxis2; ++i2)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z1 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
}
|
}
|
||||||
break;
|
}
|
||||||
|
};
|
||||||
|
|
||||||
default: {
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
|
||||||
uint xShapeInfoCast[MAX_RANK];
|
}
|
||||||
uint yShapeInfoCast[MAX_RANK];
|
|
||||||
uint zShapeInfoCast[MAX_RANK];
|
|
||||||
|
|
||||||
bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
|
void reduceExec53(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
|
||||||
{
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
|
||||||
auto threadNum = omp_get_thread_num();
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
auto threadOffset = threadsInfo.getThreadOffset(threadNum);
|
|
||||||
auto lenPerThread = static_cast<uint>(threadsInfo.getItersPerThread(threadNum));
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
|
||||||
PRAGMA_OMP_SIMD
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
|
||||||
for (uint i = 0; i < lenPerThread; i++) {
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||||
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
|
|
||||||
auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
|
||||||
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
|
||||||
z[zOffset] = op(x[xOffset], y[yOffset], extraParams);
|
const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
|
||||||
|
|
||||||
|
const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
|
||||||
|
const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
|
||||||
|
|
||||||
|
const Nd4jLong tadLen = static_cast<Nd4jLong>(xAxis3 * xAxis4);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
|
||||||
|
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
||||||
|
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x2);
|
||||||
|
|
||||||
|
if(xStrd3 == 1)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
s = OpType::update(s, OpType::op(x2[i3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
else if(xStrd4 == 1)
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x2[i3*xStrd3 + i4], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i3 = 0; i3 < xAxis3; ++i3)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x2[i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z2 = OpType::postProcess(s, tadLen, extraParams);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
void reduceExec54(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
|
const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[3]);
|
||||||
|
const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[3]);
|
||||||
|
const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
|
||||||
|
const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[2]);
|
||||||
|
const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[2]);
|
||||||
|
const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
|
||||||
|
const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[1]);
|
||||||
|
const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[1]);
|
||||||
|
const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
|
||||||
|
const uint xAxis3 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[3] : dims[0]);
|
||||||
|
const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[3] : dims[0]);
|
||||||
|
const Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
|
||||||
|
const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
|
||||||
|
const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
for (auto i3 = 0; i3 < xAxis3; ++i3) {
|
||||||
|
|
||||||
|
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
||||||
|
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
||||||
|
|
||||||
|
auto s = OpType::startingValue(x3);
|
||||||
|
|
||||||
|
if(xStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x3[i4], extraParams), extraParams);
|
||||||
|
else
|
||||||
|
for (uint i4 = 0; i4 < xAxis4; ++i4)
|
||||||
|
s = OpType::update(s, OpType::op(x3[i4*xStrd4], extraParams), extraParams);
|
||||||
|
|
||||||
|
*z3 = OpType::postProcess(s, static_cast<Nd4jLong>(xAxis4), extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
*/
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename E, typename OpType>
|
||||||
|
void reduceDefault(sd::memory::Workspace* workspace, const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
const int zRank = shape::rank(zShapeInfo);
|
||||||
template<typename X, typename Z, typename E>
|
const int tadRank = shape::rank(xShapeInfo) - zRank;
|
||||||
template <typename OpType>
|
|
||||||
void sd::ReductionLoops<X, Z, E>::loopReduce(const X* x, const Nd4jLong* xShapeInfo,
|
|
||||||
Z* z, const Nd4jLong* zShapeInfo,
|
|
||||||
const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
|
|
||||||
E* extraParams,
|
|
||||||
int64_t start, int64_t stop) {
|
|
||||||
|
|
||||||
const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
|
Nd4jLong* outerXTadShapeInfo = sd::ShapeBuilders::createSubArrShapeInfo(xShapeInfo, dims, zRank);
|
||||||
|
Nd4jLong* innerXTadShapeInfo = sd::ShapeBuilders::createSubArrShapeInfo(xShapeInfo, dims+zRank, tadRank);
|
||||||
|
|
||||||
|
const bool sameOffsets1 = shape::haveSameShapeAndStrides(zShapeInfo, outerXTadShapeInfo);
|
||||||
|
const bool sameOffsets2 = shape::haveSameShapeAndStrides(zShapeInfo, innerXTadShapeInfo);
|
||||||
|
|
||||||
const Nd4jLong zLen = shape::length(zShapeInfo);
|
const Nd4jLong zLen = shape::length(zShapeInfo);
|
||||||
const Nd4jLong tadLen = shape::length(tadShapeInfo);
|
const Nd4jLong tadLen = shape::length(innerXTadShapeInfo);
|
||||||
|
|
||||||
const uint tadEws = shape::elementWiseStride(tadShapeInfo);
|
Nd4jLong* zOffsets = nullptr;
|
||||||
const uint zEws = shape::elementWiseStride(zShapeInfo);
|
ALLOCATE(zOffsets, workspace, zLen, Nd4jLong);
|
||||||
|
shape::calcOffsets(zShapeInfo, zOffsets);
|
||||||
|
|
||||||
const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo);
|
Nd4jLong* outerXTadOffsets = zOffsets;
|
||||||
const Nd4jLong* tadStride = shape::stride(tadShapeInfo);
|
if(!sameOffsets1) {
|
||||||
|
ALLOCATE(outerXTadOffsets, workspace, zLen, Nd4jLong);
|
||||||
|
shape::calcOffsets(outerXTadShapeInfo, outerXTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen);
|
Nd4jLong* innerXTadOffsets = zOffsets;
|
||||||
|
if(!sameOffsets2) {
|
||||||
|
ALLOCATE(innerXTadOffsets, workspace, tadLen, Nd4jLong);
|
||||||
|
shape::calcOffsets(innerXTadShapeInfo, innerXTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
switch (kindOfLoop) {
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
const auto tad = x + outerXTadOffsets[i];
|
||||||
|
auto s = OpType::startingValue(tad);
|
||||||
|
|
||||||
|
for (Nd4jLong j = 0; j < tadLen; j++)
|
||||||
|
s = OpType::update(s, OpType::op(tad[innerXTadOffsets[j]], extraParams), extraParams);
|
||||||
|
|
||||||
|
z[zOffsets[i]] = OpType::postProcess(s, tadLen, extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||||
|
|
||||||
|
RELEASE(outerXTadShapeInfo, workspace);
|
||||||
|
RELEASE(innerXTadShapeInfo, workspace);
|
||||||
|
RELEASE(zOffsets, workspace);
|
||||||
|
if(!sameOffsets1)
|
||||||
|
RELEASE(outerXTadOffsets, workspace);
|
||||||
|
if(!sameOffsets2)
|
||||||
|
RELEASE(innerXTadOffsets, workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z, typename E>
|
||||||
|
template <typename OpType>
|
||||||
|
void sd::ReductionLoops<X, Z, E>::loopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, E* extraParams) {
|
||||||
|
|
||||||
|
const int xRank = shape::rank(xShapeInfo);
|
||||||
|
const int zRank = shape::rank(zShapeInfo);
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
// case LoopKind::SMALLARR2DX: {
|
|
||||||
// shape::printShapeInfoLinear(xShapeInfo);
|
// shape::printShapeInfoLinear(xShapeInfo);
|
||||||
// shape::printShapeInfoLinear(zShapeInfo);
|
// shape::printShapeInfoLinear(zShapeInfo);
|
||||||
// const auto xLen = zLen * tadLen;
|
// shape::printIntArray(dims, shape::rank(xShapeInfo));
|
||||||
// for (uint i = 0; i < xLen; ++i) {
|
|
||||||
// const auto zOffset = shape::subArrayOffset(i, xShapeInfo, zShapeInfo, dimsToExclude, dimsLen);
|
|
||||||
// const uint tadInd = (i / tadEws) % tadLen;
|
|
||||||
// auto startVal = tadInd ? z[zOffset] : static_cast<Z>(OpType::startingValue(x));
|
|
||||||
// z[zOffset] = OpType::update(startVal, OpType::op(x[i], extraParams), extraParams);
|
|
||||||
// if(tadInd == tadLen - 1)
|
|
||||||
// z[zOffset] = OpType::postProcess(z[zOffset], tadLen, extraParams);
|
|
||||||
// printf("%u - %lld\n", i, zOffset);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
case LoopKind::SMALLARR2DX: {
|
|
||||||
const auto uTadLen = static_cast<uint>(tadLen);
|
|
||||||
const auto uZLenMinusOne = static_cast<uint>(zLen - 1);
|
|
||||||
const auto xLen = static_cast<uint>(zLen * uTadLen);
|
|
||||||
const auto sv = static_cast<Z>(OpType::startingValue(x));
|
|
||||||
|
|
||||||
for (uint i = 0; i <= uZLenMinusOne; i++)
|
if(xRank == 2 && zRank == 1)
|
||||||
z[i] = OpType::startingValue(x);
|
reduceExec21<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
else if(xRank == 3 && zRank == 1)
|
||||||
uint zOffset = 0;
|
reduceExec31<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
for (uint i = 0; i < xLen; ++i) {
|
else if(xRank == 3 && zRank == 2)
|
||||||
z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), extraParams);
|
reduceExec32<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1;
|
else if(xRank == 4 && zRank == 1)
|
||||||
}
|
reduceExec41<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
else if(xRank == 4 && zRank == 2)
|
||||||
for (uint i = 0; i <= uZLenMinusOne; i++)
|
reduceExec42<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
z[i] = OpType::postProcess(z[i], tadLen, extraParams);
|
else if(xRank == 4 && zRank == 3)
|
||||||
}
|
reduceExec43<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
break;
|
else if(xRank == 5 && zRank == 1)
|
||||||
|
reduceExec51<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
//*********************************************//
|
else if(xRank == 5 && zRank == 2)
|
||||||
case LoopKind::EWS1: {
|
reduceExec52<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
for (auto i = start; i < stop; i++) {
|
else if(xRank == 5 && zRank == 3)
|
||||||
auto tad = x + tadOffsets[i];
|
reduceExec53<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
auto s = OpType::startingValue(tad);
|
else if(xRank == 5 && zRank == 4)
|
||||||
|
reduceExec54<X,Z,E,OpType>(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
for (Nd4jLong j = 0; j < tadLen; j++)
|
else
|
||||||
s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams);
|
reduceDefault<X,Z,E,OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
}
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::EWSNONZERO: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong j = 0; j < tadLen; j++)
|
|
||||||
s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i * zEws] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::RANK1: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong i0 = 0; i0 < tadLen; ++i0)
|
|
||||||
s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::RANK2: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
|
|
||||||
for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
|
|
||||||
s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::RANK3: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
|
|
||||||
for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
|
|
||||||
for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
|
|
||||||
s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::RANK4: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
|
|
||||||
for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
|
|
||||||
for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
|
|
||||||
for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3)
|
|
||||||
s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::RANK5: {
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
|
|
||||||
for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
|
|
||||||
for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
|
|
||||||
for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3)
|
|
||||||
for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4)
|
|
||||||
s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]], extraParams), extraParams);
|
|
||||||
|
|
||||||
z[i] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::X_EWSNONZERO: {
|
|
||||||
uint castZShapeInfo[MAX_RANK];
|
|
||||||
const bool canCastZ = sd::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong j = 0; j < tadLen; j++)
|
|
||||||
s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams);
|
|
||||||
|
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
|
||||||
z[zOffset] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
case LoopKind::Z_EWSNONZERO: {
|
|
||||||
uint castTadShapeInfo[MAX_RANK];
|
|
||||||
const bool canCastTad = sd::DataTypeUtils::castShapeInfo<uint>(tadShapeInfo, castTadShapeInfo);
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong j = 0; j < tadLen; j++) {
|
|
||||||
auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
|
|
||||||
s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), extraParams);
|
|
||||||
}
|
|
||||||
|
|
||||||
z[i * zEws] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
//*********************************************//
|
|
||||||
default: {
|
|
||||||
auto innertadOffsets = new Nd4jLong[tadLen];
|
|
||||||
shape::calcOffsets(tadShapeInfo, innertadOffsets);
|
|
||||||
|
|
||||||
uint castZShapeInfo[MAX_RANK];
|
|
||||||
const bool canCastZ = sd::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, castZShapeInfo);
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; i++) {
|
|
||||||
auto tad = x + tadOffsets[i];
|
|
||||||
auto s = OpType::startingValue(tad);
|
|
||||||
|
|
||||||
for (Nd4jLong j = 0; j < tadLen; j++)
|
|
||||||
s = OpType::update(s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams);
|
|
||||||
|
|
||||||
auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
|
|
||||||
z[zOffset] = OpType::postProcess(s, tadLen, extraParams);
|
|
||||||
};
|
|
||||||
|
|
||||||
delete[] innertadOffsets;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -52,11 +52,9 @@ namespace sd {
|
||||||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
|
* allocates memory for sub-array shapeInfo and copy shape and strides at axes(positions) stored in dims
|
||||||
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
|
|
||||||
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
|
||||||
*/
|
*/
|
||||||
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,12 @@ namespace sd {
|
||||||
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
|
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
|
||||||
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
|
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
|
||||||
|
// for example
|
||||||
|
// if rank = 3 and dimsToExclude = {0,2} then output = {1,0,2}, if rank = 3 and dimsToExclude = {2} then output = {0,1,2}
|
||||||
|
// if rank = 3 and dimsToExclude = {0} then output = {1,2,0}, if rank = 4 and dimsToExclude = {0,3} then output = {1,2,0,3}
|
||||||
|
static std::vector<int> evalDimsForReduceOp(const int rank, const std::vector<int>& dimsToExclude);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* evaluate output shape for reduce operation when input shape is empty
|
* evaluate output shape for reduce operation when input shape is empty
|
||||||
* behavior is analogous to tf
|
* behavior is analogous to tf
|
||||||
|
|
|
@ -94,9 +94,9 @@ namespace sd {
|
||||||
auto tadOffsets = Environment::getInstance().isCPU() ? pack.primaryOffsets() : pack.specialOffsets();
|
auto tadOffsets = Environment::getInstance().isCPU() ? pack.primaryOffsets() : pack.specialOffsets();
|
||||||
|
|
||||||
if (_opType == 0)
|
if (_opType == 0)
|
||||||
NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets);
|
NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size());
|
||||||
else
|
else
|
||||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets);
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
|
|
|
@ -184,6 +184,43 @@ ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast
|
||||||
|
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
} // namespace sd
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* inShapeInfo, const std::vector<int> &dimsWithUnities, sd::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
Nd4jLong* newShapeInfo = nullptr;
|
||||||
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()), Nd4jLong);
|
||||||
|
|
||||||
|
int temp;
|
||||||
|
if(dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
|
||||||
|
auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
|
||||||
|
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
|
||||||
|
} else {
|
||||||
|
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsWithUnities.data(), dimsWithUnities.size(), newShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(newShapeInfo);
|
||||||
|
|
||||||
|
RELEASE(newShapeInfo, workspace);
|
||||||
|
|
||||||
|
return bufferForShapeInfo(descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
ConstantShapeBuffer& ConstantShapeHelper::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
Nd4jLong* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(newShapeInfo);
|
||||||
|
|
||||||
|
RELEASE(newShapeInfo, workspace);
|
||||||
|
|
||||||
|
return bufferForShapeInfo(descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -443,17 +443,17 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
|
||||||
// calculate index of current batch
|
// calculate index of current batch
|
||||||
Nd4jLong batchInd;
|
Nd4jLong batchInd;
|
||||||
if(cRank > 2)
|
if(cRank > 2)
|
||||||
batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims);
|
batchInd = shape::coords2index(cShapeInfo, cBatchDims, cRank - 2, cCoords.data());
|
||||||
|
|
||||||
// evaluate A coordinates
|
// evaluate A coordinates
|
||||||
if(aRank > 2)
|
if(aRank > 2)
|
||||||
shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims);
|
shape::index2coords(batchInd, aShapeInfo, aBatchDims, aRank - 2, aCoords.data());
|
||||||
aCoords[aMaxis] = cCoords[cMaxis];
|
aCoords[aMaxis] = cCoords[cMaxis];
|
||||||
aCoords[aKaxis] = 0;
|
aCoords[aKaxis] = 0;
|
||||||
|
|
||||||
// evaluate B coordinates
|
// evaluate B coordinates
|
||||||
if(bRank > 2)
|
if(bRank > 2)
|
||||||
shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims);
|
shape::index2coords(batchInd, bShapeInfo, bBatchDims, bRank - 2, bCoords.data());
|
||||||
bCoords[bKaxis] = 0;
|
bCoords[bKaxis] = 0;
|
||||||
bCoords[bNaxis] = cCoords[cNaxis];
|
bCoords[bNaxis] = cCoords[cNaxis];
|
||||||
|
|
||||||
|
|
|
@ -26,20 +26,19 @@ namespace sd {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionBoolLoops<X, Z>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
void ReductionBoolLoops<X, Z>::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
ReductionLoops<X,Z,X>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionBoolLoops<X, Y>::wrapper(const int opNum,
|
void ReductionBoolLoops<X, Y>::wrapper(const int opNum, sd::memory::Workspace* workspace,
|
||||||
const X *x, const Nd4jLong *xShapeInfo,
|
const X *x, const Nd4jLong *xShapeInfo,
|
||||||
Y *z, const Nd4jLong *zShapeInfo,
|
Y *z, const Nd4jLong *zShapeInfo,
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
|
const int *dims, X *extraParams) {
|
||||||
X *extraParams, int64_t start, int64_t stop) {
|
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_BOOL_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,20 +28,19 @@ namespace sd {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionFloatLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
|
void ReductionFloatLoops<X, Z>::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const int* dims, Z* extraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
ReductionLoops<X,Z,Z>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo,
|
void ReductionFloatLoops<X, Y>::wrapper(const int opNum, sd::memory::Workspace* workspace,
|
||||||
|
const X *x, const Nd4jLong *xShapeInfo,
|
||||||
Y *z, const Nd4jLong *zShapeInfo,
|
Y *z, const Nd4jLong *zShapeInfo,
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
|
const int *dims, Y *extraParams) {
|
||||||
Y *extraParams,
|
|
||||||
int64_t start, int64_t stop) {
|
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_FLOAT_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,18 +33,19 @@ namespace sd {
|
||||||
|
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionLongLoops<X, Z>::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
void ReductionLongLoops<X, Z>::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
ReductionLoops<X,Z,X>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
void ReductionLongLoops<X, Y>::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
|
void ReductionLongLoops<X, Y>::wrapper(const int opNum, sd::memory::Workspace* workspace,
|
||||||
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
|
const X *x, const Nd4jLong *xShapeInfo,
|
||||||
const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
|
Y *z, const Nd4jLong *zShapeInfo,
|
||||||
|
const int *dims, X *extraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
|
DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_LONG_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,23 +26,23 @@ namespace sd {
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void ReductionSameLoops<X>::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
|
void ReductionSameLoops<X>::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
ReductionLoops<X,X,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
|
ReductionLoops<X,X,X>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename X>
|
template<typename X>
|
||||||
void ReductionSameLoops<X>::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz,
|
void ReductionSameLoops<X>::wrapper(const int opNum, sd::memory::Workspace* workspace,
|
||||||
const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
|
const X *vx, const Nd4jLong *xShapeInfo,
|
||||||
const Nd4jLong *tadOffsets,
|
X *z, const Nd4jLong *zShapeInfo,
|
||||||
X *vextraParams, int64_t start, int64_t stop) {
|
const int *dims, X *vextraParams) {
|
||||||
#ifndef INLINE_LOOPS
|
#ifndef INLINE_LOOPS
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto z = reinterpret_cast<X *>(vz);
|
auto z = reinterpret_cast<X *>(vz);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS);
|
DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_SAME_OPS);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <helpers/ShapeBuilders.h>
|
#include <helpers/ShapeBuilders.h>
|
||||||
#include <execution/AffinityManager.h>
|
#include <execution/AffinityManager.h>
|
||||||
#include <helpers/ConstantHelper.h>
|
#include <helpers/ConstantHelper.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
#include <array/PrimaryPointerDeallocator.h>
|
#include <array/PrimaryPointerDeallocator.h>
|
||||||
#include <array/CudaPointerDeallocator.h>
|
#include <array/CudaPointerDeallocator.h>
|
||||||
|
|
||||||
|
@ -187,4 +188,38 @@ ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcas
|
||||||
return bufferForShapeInfo(descriptor);
|
return bufferForShapeInfo(descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* inShapeInfo, const std::vector<int> &dimsWithUnities, sd::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
Nd4jLong* newShapeInfo = nullptr;
|
||||||
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()), Nd4jLong);
|
||||||
|
|
||||||
|
int temp;
|
||||||
|
if(dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
|
||||||
|
auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
|
||||||
|
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
|
||||||
|
} else {
|
||||||
|
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsWithUnities.data(), dimsWithUnities.size(), newShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(newShapeInfo);
|
||||||
|
|
||||||
|
RELEASE(newShapeInfo, workspace);
|
||||||
|
|
||||||
|
return bufferForShapeInfo(descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
ConstantShapeBuffer& ConstantShapeHelper::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
Nd4jLong* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(newShapeInfo);
|
||||||
|
|
||||||
|
RELEASE(newShapeInfo, workspace);
|
||||||
|
|
||||||
|
return bufferForShapeInfo(descriptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -571,17 +571,17 @@ static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInf
|
||||||
// calculate index of current batch
|
// calculate index of current batch
|
||||||
Nd4jLong batchInd;
|
Nd4jLong batchInd;
|
||||||
if(cBatchDims != nullptr)
|
if(cBatchDims != nullptr)
|
||||||
batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims);
|
batchInd = shape::coords2index(cShapeInfo, cBatchDims, cRank - 2, cCoords);
|
||||||
|
|
||||||
// evaluate A coordinates
|
// evaluate A coordinates
|
||||||
if(aBatchDims != nullptr)
|
if(aBatchDims != nullptr)
|
||||||
shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims);
|
shape::index2coords(batchInd, aShapeInfo, aBatchDims, aRank - 2, aCoords);
|
||||||
aCoords[aMaxis] = cCoords[cMaxis];
|
aCoords[aMaxis] = cCoords[cMaxis];
|
||||||
aCoords[aKaxis] = 0;
|
aCoords[aKaxis] = 0;
|
||||||
|
|
||||||
// evaluate B coordinates
|
// evaluate B coordinates
|
||||||
if(bBatchDims != nullptr)
|
if(bBatchDims != nullptr)
|
||||||
shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims);
|
shape::index2coords(batchInd, bShapeInfo, bBatchDims, bRank - 2, bCoords);
|
||||||
bCoords[bKaxis] = 0;
|
bCoords[bKaxis] = 0;
|
||||||
bCoords[bNaxis] = cCoords[cNaxis];
|
bCoords[bNaxis] = cCoords[cNaxis];
|
||||||
|
|
||||||
|
|
|
@ -140,14 +140,26 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
|
Nd4jLong* ShapeBuilders::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, memory::Workspace* workspace) {
|
||||||
|
|
||||||
Nd4jLong *outShapeInfo = nullptr;
|
Nd4jLong *subArrShapeInfo = nullptr;
|
||||||
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
|
ALLOCATE(subArrShapeInfo, workspace, shape::shapeInfoLength(dimsSize), Nd4jLong);
|
||||||
|
|
||||||
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
|
subArrShapeInfo[0] = dimsSize; // rank
|
||||||
|
sd::ArrayOptions::copyDataType(subArrShapeInfo, inShapeInfo); // type
|
||||||
|
subArrShapeInfo[2*dimsSize + 3] = shape::order(inShapeInfo); // order
|
||||||
|
|
||||||
return outShapeInfo;
|
Nd4jLong* shape = shape::shapeOf(subArrShapeInfo);
|
||||||
|
Nd4jLong* strides = shape::stride(subArrShapeInfo);
|
||||||
|
|
||||||
|
for(int i = 0; i < dimsSize; ++i) {
|
||||||
|
shape[i] = shape::sizeAt(inShapeInfo, dims[i]);
|
||||||
|
strides[i] = shape::strideAt(inShapeInfo, dims[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
shape::checkStridesEwsAndOrder(subArrShapeInfo);
|
||||||
|
|
||||||
|
return subArrShapeInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -1062,6 +1062,17 @@ bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
std::vector<int> ShapeUtils::evalDimsForReduceOp(const int rank, const std::vector<int>& dimsToExclude) {
|
||||||
|
|
||||||
|
std::vector<int> output = ShapeUtils::evalDimsToExclude(rank, dimsToExclude);
|
||||||
|
|
||||||
|
for(uint j = 0; j < dimsToExclude.size(); ++j)
|
||||||
|
output.emplace_back(dimsToExclude[j]);
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
/*
|
/*
|
||||||
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
||||||
|
|
|
@ -901,6 +901,16 @@ namespace shape {
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0);
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0);
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, const int* dims); // length of dims is equal to rank of shapeInfo
|
||||||
|
|
||||||
|
// all three arrays should have same rank
|
||||||
|
// all three arrays should have same dimensions or some of them are 1 (that is satisfy broadcasting principle), strides may be different
|
||||||
|
// shapeInfo1 - first array should have max length compared to rest of two arrays
|
||||||
|
ND4J_EXPORT _CUDA_HD void getOffsetBroadcast(const Nd4jLong& startInd, const Nd4jLong ind,
|
||||||
|
const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2, const Nd4jLong* shapeInfo3,
|
||||||
|
const bool sameOffsets12, const bool sameOffsets13,
|
||||||
|
int* coords,
|
||||||
|
Nd4jLong& offset1, Nd4jLong& offset2, Nd4jLong& offset3);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
||||||
|
|
||||||
|
@ -918,11 +928,12 @@ namespace shape {
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
|
||||||
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords);
|
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords);
|
||||||
|
// ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, const int* dims, Nd4jLong *coords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||||
|
@ -935,7 +946,7 @@ namespace shape {
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims);
|
ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int* dims, const int dimsSize, const int *coords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* increment n-dimensional array by one iteration by changing coord appropriately
|
* increment n-dimensional array by one iteration by changing coord appropriately
|
||||||
|
@ -951,7 +962,7 @@ namespace shape {
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo);
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned);
|
ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo);
|
ND4J_EXPORT _CUDA_HD void printShapeInfo(const Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo);
|
ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
|
@ -1057,10 +1068,10 @@ namespace shape {
|
||||||
ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities);
|
ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2
|
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude(points on unity dimensions) = {1,3}, dimsSize = 2
|
||||||
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
||||||
*/
|
*/
|
||||||
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);
|
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int* dimsToExclude, const int dimsSize, Nd4jLong* outShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||||
|
@ -1847,13 +1858,13 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape,
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims) {
|
INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, const int *coords) {
|
||||||
|
|
||||||
Nd4jLong index, shift = 1;;
|
Nd4jLong index, shift = 1;;
|
||||||
|
|
||||||
index = coords[tadDims[dimsSize - 1]];
|
index = coords[dims[dimsLen - 1]];
|
||||||
for(uint i = dimsSize - 1; i >= 1; --i) {
|
for(uint i = dimsLen - 1; i >= 1; --i) {
|
||||||
shift *= shapeInfo[tadDims[i]];
|
shift *= shapeInfo[dims[i]];
|
||||||
index += shift * coords[i - 1];
|
index += shift * coords[i - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3324,6 +3335,18 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
|
||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
|
||||||
|
|
||||||
|
Nd4jLong offset = baseOffset;
|
||||||
|
|
||||||
|
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
||||||
|
if(shapeInfo[i] != 1)
|
||||||
|
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {
|
||||||
|
|
||||||
|
@ -3337,17 +3360,78 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coor
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, const int* dims) {
|
||||||
|
|
||||||
Nd4jLong offset = baseOffset;
|
Nd4jLong offset = 0;
|
||||||
|
|
||||||
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
for(uint i = 1; i <= shapeInfo[0]; ++i)
|
||||||
if(shapeInfo[i] != 1)
|
if(shapeInfo[i] != 1)
|
||||||
offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
|
offset += coords[dims[i - 1]] * shapeInfo[shapeInfo[0] + i];
|
||||||
|
|
||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD void getOffsetBroadcast(const Nd4jLong& startInd, const Nd4jLong ind,
|
||||||
|
const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2, const Nd4jLong* shapeInfo3,
|
||||||
|
const bool sameOffsets12, const bool sameOffsets13,
|
||||||
|
int* coords,
|
||||||
|
Nd4jLong& offset1, Nd4jLong& offset2, Nd4jLong& offset3) {
|
||||||
|
|
||||||
|
const Nd4jLong* shape1 = shape::shapeOf(shapeInfo1);
|
||||||
|
const Nd4jLong* strides1 = shape::stride(shapeInfo1);
|
||||||
|
const Nd4jLong* shape2 = shape::shapeOf(shapeInfo2);
|
||||||
|
const Nd4jLong* strides2 = shape::stride(shapeInfo2);
|
||||||
|
const Nd4jLong* shape3 = shape::shapeOf(shapeInfo3);
|
||||||
|
const Nd4jLong* strides3 = shape::stride(shapeInfo3);
|
||||||
|
|
||||||
|
if(startInd == ind) {
|
||||||
|
|
||||||
|
if(shape::rank(shapeInfo1) == 0) {
|
||||||
|
offset1 = offset2 = offset3 = 0;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
shape::index2coords(ind, shapeInfo1, coords);
|
||||||
|
offset1 = shape::getOffset(shapeInfo1, coords);
|
||||||
|
|
||||||
|
if(sameOffsets12)
|
||||||
|
offset2 = offset1;
|
||||||
|
else
|
||||||
|
offset2 = shape::getOffset(shapeInfo2, coords);
|
||||||
|
|
||||||
|
if(sameOffsets13)
|
||||||
|
offset3 = offset1;
|
||||||
|
else
|
||||||
|
offset3 = shape::getOffset(shapeInfo3, coords);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int axis = shapeInfo1[0] - 1;
|
||||||
|
while(coords[axis] == shape1[axis] - 1) {
|
||||||
|
if(!sameOffsets12 && shape2[axis] != 1)
|
||||||
|
offset2 -= (shape2[axis] - 1) * strides2[axis];
|
||||||
|
if(!sameOffsets13 && shape3[axis] != 1)
|
||||||
|
offset3 -= (shape3[axis] - 1) * strides3[axis];
|
||||||
|
if(shape1[axis] != 1)
|
||||||
|
offset1 -= (shape1[axis] - 1) * strides1[axis];
|
||||||
|
coords[axis--] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
++coords[axis];
|
||||||
|
offset1 += strides1[axis];
|
||||||
|
|
||||||
|
if(!sameOffsets12 && shape2[axis] != 1)
|
||||||
|
offset2 += strides2[axis];
|
||||||
|
if(!sameOffsets13 && shape3[axis] != 1)
|
||||||
|
offset3 += strides3[axis];
|
||||||
|
|
||||||
|
if(sameOffsets12)
|
||||||
|
offset2 = offset1;
|
||||||
|
if(sameOffsets13)
|
||||||
|
offset3 = offset1;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the tensor along dimension
|
* Returns the tensor along dimension
|
||||||
|
@ -3443,7 +3527,7 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) {
|
INLINEDEF _CUDA_HD void printShapeInfo(const Nd4jLong *shapeInfo) {
|
||||||
int rank = shape::rank(shapeInfo);
|
int rank = shape::rank(shapeInfo);
|
||||||
Nd4jLong *shape = shape::shapeOf(shapeInfo);
|
Nd4jLong *shape = shape::shapeOf(shapeInfo);
|
||||||
printf("Rank %d\n",rank);
|
printf("Rank %d\n",rank);
|
||||||
|
@ -4583,89 +4667,92 @@ INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const c
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order) {
|
INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order) {
|
||||||
|
|
||||||
// if(false) { // tests showed that this code did calculation notably slower even for big N
|
const uint64_t len = shape::prodLong(shape, rank);
|
||||||
// Nd4jLong indexes[MAX_RANK];
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes))
|
|
||||||
// for (Nd4jLong i = 0; i < N; ++i) {
|
|
||||||
// shape::index2coords(rank, shape, i, indexes);
|
|
||||||
// subArrOffsets[i] = 0;
|
|
||||||
// for (int j = 0; j < rank; ++j)
|
|
||||||
// if(shape[j] != 1)
|
|
||||||
// subArrOffsets[i] += indexes[j] * strides[j];
|
|
||||||
// }
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// set offset for first sub-array, it is equal to zero always
|
// set offset for first sub-array, it is equal to zero always
|
||||||
offsets[0] = 0;
|
offsets[0] = 0;
|
||||||
|
|
||||||
Nd4jLong * idx = new Nd4jLong[rank];
|
uint coords[MAX_RANK];
|
||||||
Nd4jLong* offsetPerDim = new Nd4jLong[rank];
|
memset(coords, 0, sizeof(uint) * rank);
|
||||||
memset(idx, 0, sizeof(Nd4jLong) * rank);
|
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
|
||||||
for (int k = 0; k < rank; ++k)
|
|
||||||
offsetPerDim[k] = (shape[k] - 1) * strides[k];
|
|
||||||
|
|
||||||
Nd4jLong init = 0, i = 1;
|
|
||||||
// nested loops - calculation of sub-array offsets
|
|
||||||
if(order == 'c') {
|
if(order == 'c') {
|
||||||
|
|
||||||
Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne;
|
for (uint64_t i = 1; i < len; ++i) {
|
||||||
|
int axis = rank - 1;
|
||||||
|
offsets[i] = 0;
|
||||||
|
while(coords[axis] == shape[axis] - 1) {
|
||||||
|
offsets[i] -= (shape[axis] - 1) * strides[axis];
|
||||||
|
coords[axis--] = 0;
|
||||||
|
}
|
||||||
|
++coords[axis];
|
||||||
|
offsets[i] += offsets[i-1] + strides[axis];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
while(j >= 0) {
|
for (uint64_t i = 1; i < len; ++i) {
|
||||||
|
int axis = 0;
|
||||||
if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
|
offsets[i] = 0;
|
||||||
|
while(coords[axis] == shape[axis] - 1) {
|
||||||
if(j == rankMinusOne) { // last dimension
|
offsets[i] -= (shape[axis] - 1) * strides[axis];
|
||||||
for(int l = 1; l < shape[j]; ++l) {
|
coords[axis++] = 0;
|
||||||
offsets[i] = offsets[i - 1] + strides[j];
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
--j;
|
|
||||||
}
|
|
||||||
else if(idx[j] < shape[j] - 1) {
|
|
||||||
init += strides[j];
|
|
||||||
offsets[i++] = init;
|
|
||||||
++idx[j];
|
|
||||||
j = rankMinusOne;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
init -= offsetPerDim[j];
|
|
||||||
idx[j--] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
|
|
||||||
Nd4jLong j = 0;
|
|
||||||
|
|
||||||
while(j < rank) {
|
|
||||||
|
|
||||||
if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
|
|
||||||
|
|
||||||
if(j == 0) { // last dimension
|
|
||||||
for(int l = 1; l < shape[j]; ++l) {
|
|
||||||
offsets[i] = offsets[i - 1] + strides[j];
|
|
||||||
i++;
|
|
||||||
}
|
|
||||||
++j;
|
|
||||||
}
|
|
||||||
else if(idx[j] < shape[j] - 1) {
|
|
||||||
init += strides[j];
|
|
||||||
offsets[i++] = init;
|
|
||||||
++idx[j];
|
|
||||||
j = 0;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
init -= offsetPerDim[j];
|
|
||||||
idx[j++] = 0;
|
|
||||||
}
|
}
|
||||||
|
++coords[axis];
|
||||||
|
offsets[i] += offsets[i-1] + strides[axis];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
delete []idx;
|
// Nd4jLong init = 0, i = 1;
|
||||||
delete []offsetPerDim;
|
// // nested loops - calculation of sub-array offsets
|
||||||
|
// if(order == 'c') {
|
||||||
|
|
||||||
|
// int rankMinusOne = rank - 1, j = rankMinusOne;
|
||||||
|
|
||||||
|
// while(j >= 0) {
|
||||||
|
|
||||||
|
// if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
|
||||||
|
|
||||||
|
// if(j == rankMinusOne) { // last dimension
|
||||||
|
// for(uint l = 1; l < shape[j]; ++l)
|
||||||
|
// offsets[i++] = offsets[i - 1] + strides[j];
|
||||||
|
// --j;
|
||||||
|
// }
|
||||||
|
// else if(coords[j] < shape[j] - 1) {
|
||||||
|
// init += strides[j];
|
||||||
|
// offsets[i++] = init;
|
||||||
|
// ++coords[j];
|
||||||
|
// j = rankMinusOne;
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// init -= (shape[j] - 1) * strides[j];
|
||||||
|
// coords[j--] = 0;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
|
||||||
|
// int j = 0;
|
||||||
|
|
||||||
|
// while(j < rank) {
|
||||||
|
|
||||||
|
// if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
|
||||||
|
|
||||||
|
// if(j == 0) { // last dimension
|
||||||
|
// for(uint l = 1; l < shape[j]; ++l)
|
||||||
|
// offsets[i++] = offsets[i - 1] + strides[j];
|
||||||
|
// ++j;
|
||||||
|
// }
|
||||||
|
// else if(coords[j] < shape[j] - 1) {
|
||||||
|
// init += strides[j];
|
||||||
|
// offsets[i++] = init;
|
||||||
|
// ++coords[j];
|
||||||
|
// j = 0;
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// init -= (shape[j] - 1) * strides[j];
|
||||||
|
// coords[j++] = 0;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -4884,13 +4971,14 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims) {
|
INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords) {
|
||||||
|
|
||||||
for(uint i = dimsSize - 1; i > 0; --i) {
|
for(uint i = dimsLen - 1; i > 0; --i) {
|
||||||
coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]];
|
const auto ind = dims[i];
|
||||||
index /= shapeInfo[1 + tadDims[i]];
|
coords[ind] = index % shapeInfo[1 + ind];
|
||||||
|
index /= shapeInfo[1 + ind];
|
||||||
}
|
}
|
||||||
coords[tadDims[0]] = index; // last iteration
|
coords[dims[0]] = index; // last iteration
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -4921,6 +5009,64 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
|
||||||
|
|
||||||
|
const int rank = shape::rank(inShapeInfo);
|
||||||
|
const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
|
||||||
|
|
||||||
|
if(numOfNonUnities == rank) { // no unities in shape, no copy procedure
|
||||||
|
shapeNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1;
|
||||||
|
stridesNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1 + rank;
|
||||||
|
return numOfNonUnities;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(uint j = 0, i = 0; i < rank; ++i) {
|
||||||
|
if(shape::shapeOf(inShapeInfo)[i] != 1) {
|
||||||
|
shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
|
||||||
|
shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stridesNoUnities = shapeNoUnities + numOfNonUnities;
|
||||||
|
|
||||||
|
return numOfNonUnities;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int* dimsToExclude, const int dimsSize, Nd4jLong* outShapeInfo) {
|
||||||
|
|
||||||
|
outShapeInfo[0] = inShapeInfo[0] - dimsSize;
|
||||||
|
|
||||||
|
for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
|
||||||
|
if(j < dimsSize && i == dimsToExclude[j]) {
|
||||||
|
++j;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i];
|
||||||
|
shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type
|
||||||
|
*shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
|
||||||
|
outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords) {
|
||||||
|
|
||||||
|
// if(startIndex == index) {
|
||||||
|
// shape::index2coords(index, shapeInfo, dims, dimsLen, coords);
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// int i = dimsLen - 1;
|
||||||
|
// while(coords[dims[i]] == shape::sizeAt(shapeInfo, dims[i]) - 1)
|
||||||
|
// coords[dims[i--]] = 0;
|
||||||
|
// ++coords[dims[i]];
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||||
|
|
||||||
|
@ -5111,50 +5257,6 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
|
|
||||||
|
|
||||||
const int rank = shape::rank(inShapeInfo);
|
|
||||||
const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
|
|
||||||
|
|
||||||
if(numOfNonUnities == rank) { // no unities in shape, no copy procedure
|
|
||||||
shapeNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1;
|
|
||||||
stridesNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1 + rank;
|
|
||||||
return numOfNonUnities;
|
|
||||||
}
|
|
||||||
|
|
||||||
for(uint j = 0, i = 0; i < rank; ++i) {
|
|
||||||
if(shape::shapeOf(inShapeInfo)[i] != 1) {
|
|
||||||
shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
|
|
||||||
shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stridesNoUnities = shapeNoUnities + numOfNonUnities;
|
|
||||||
|
|
||||||
return numOfNonUnities;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) {
|
|
||||||
|
|
||||||
outShapeInfo[0] = inShapeInfo[0] - dimsSize;
|
|
||||||
|
|
||||||
for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
|
|
||||||
if(j < dimsSize && i == dimsToExclude[j]) {
|
|
||||||
++j;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i];
|
|
||||||
shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type
|
|
||||||
*shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
|
|
||||||
outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
||||||
|
|
|
@ -470,8 +470,7 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength);
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
|
|
||||||
|
|
||||||
static void execReduceSame(sd::LaunchContext *lc,
|
static void execReduceSame(sd::LaunchContext *lc,
|
||||||
int opNum,
|
int opNum,
|
||||||
|
@ -480,8 +479,7 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength);
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
|
|
||||||
|
|
||||||
static void execReduceBool(sd::LaunchContext *lc,
|
static void execReduceBool(sd::LaunchContext *lc,
|
||||||
int opNum,
|
int opNum,
|
||||||
|
@ -490,8 +488,7 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength);
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
|
|
||||||
|
|
||||||
static void execReduceLong(sd::LaunchContext *lc,
|
static void execReduceLong(sd::LaunchContext *lc,
|
||||||
int opNum,
|
int opNum,
|
||||||
|
@ -500,8 +497,7 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength);
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
|
|
@ -585,8 +585,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -597,13 +596,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
||||||
if (shape::isEmpty(hZShapeInfo))
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
};
|
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -614,24 +607,16 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
|
||||||
|
|
||||||
// nothing to do here if result is empty
|
// nothing to do here if result is empty
|
||||||
if (shape::isEmpty(hZShapeInfo))
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES);
|
|
||||||
};
|
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -642,8 +627,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
|
@ -653,13 +637,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
if (shape::isEmpty(hZShapeInfo))
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
|
||||||
};
|
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -670,8 +648,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, const Nd4jLong *hZShapeInfo,
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, const Nd4jLong *dZShapeInfo,
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
|
@ -681,13 +658,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
||||||
if (shape::isEmpty(hZShapeInfo))
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, LONG_TYPES);
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES);
|
|
||||||
};
|
|
||||||
|
|
||||||
const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
|
|
||||||
|
|
||||||
samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -447,28 +447,26 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
|
||||||
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
||||||
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
||||||
try {
|
try {
|
||||||
|
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
|
|
||||||
auto hTADShapeInfo = tadPackX.primaryShapeInfo();
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
auto hTADOffsets = tadPackX.primaryOffsets();
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
const Nd4jLong* zShapeInfoD = dZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
NativeOpExecutioner::execReduceFloat(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceFloat(nullptr, opNum,
|
|
||||||
dbX->primary(),
|
|
||||||
hXShapeInfo,
|
|
||||||
dbX->special(),
|
|
||||||
dXShapeInfo,
|
|
||||||
extraParams,
|
|
||||||
dbZ->primary(),
|
|
||||||
hZShapeInfo,
|
|
||||||
dbZ->special(),
|
|
||||||
dZShapeInfo,
|
|
||||||
dimension,
|
|
||||||
dimensionLength,
|
|
||||||
hTADShapeInfo,
|
|
||||||
hTADOffsets);
|
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
@ -481,30 +479,27 @@ void execReduceBool2(Nd4jPointer *extraPointers,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
|
||||||
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
auto dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
dimensionLength);
|
|
||||||
|
|
||||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
auto hTADOffsets = tadPack.primaryOffsets();
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
const Nd4jLong* zShapeInfoD = dZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo)) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
NativeOpExecutioner::execReduceBool(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceBool(nullptr, opNum,
|
|
||||||
dbX->primary(),
|
|
||||||
hXShapeInfo,
|
|
||||||
dbX->special(),
|
|
||||||
dXShapeInfo,
|
|
||||||
extraParams,
|
|
||||||
dbZ->primary(),
|
|
||||||
hZShapeInfo,
|
|
||||||
dbZ->special(),
|
|
||||||
dZShapeInfo,
|
|
||||||
dimension,
|
|
||||||
dimensionLength,
|
|
||||||
hTADShapeInfo,
|
|
||||||
hTADOffsets);
|
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
@ -521,26 +516,22 @@ void execReduceSame2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
dimensionLength);
|
|
||||||
|
|
||||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
auto hTADOffsets = tadPack.primaryOffsets();
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
const Nd4jLong* zShapeInfoD = dZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
NativeOpExecutioner::execReduceSame(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceSame(nullptr, opNum,
|
|
||||||
dbX->primary(),
|
|
||||||
hXShapeInfo,
|
|
||||||
dbX->special(),
|
|
||||||
dXShapeInfo,
|
|
||||||
extraParams,
|
|
||||||
dbZ->primary(),
|
|
||||||
hZShapeInfo,
|
|
||||||
dbZ->special(),
|
|
||||||
dZShapeInfo,
|
|
||||||
dimension,
|
|
||||||
dimensionLength,
|
|
||||||
hTADShapeInfo,
|
|
||||||
hTADOffsets);
|
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
@ -557,25 +548,22 @@ void execReduceLong2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
|
|
||||||
auto hTADShapeInfo = tadPack.primaryShapeInfo();
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
auto hTADOffsets = tadPack.primaryOffsets();
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
const Nd4jLong* zShapeInfoD = dZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
NativeOpExecutioner::execReduceLong(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceLong(nullptr, opNum,
|
|
||||||
dbX->primary(),
|
|
||||||
hXShapeInfo,
|
|
||||||
dbX->special(),
|
|
||||||
dXShapeInfo,
|
|
||||||
extraParams,
|
|
||||||
dbZ->primary(),
|
|
||||||
hZShapeInfo,
|
|
||||||
dbZ->special(),
|
|
||||||
dZShapeInfo,
|
|
||||||
dimension,
|
|
||||||
dimensionLength,
|
|
||||||
hTADShapeInfo,
|
|
||||||
hTADOffsets);
|
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
|
|
@ -210,7 +210,7 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
|
||||||
dim3 launchDims = dim3(256, 256, 32768);
|
dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -577,8 +577,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong const* hZShapeInfo,
|
void *hZ, Nd4jLong const* hZShapeInfo,
|
||||||
void *dZ, Nd4jLong const* dZShapeInfo,
|
void *dZ, Nd4jLong const* dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
|
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
@ -588,15 +587,14 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
|
||||||
|
|
||||||
if (zType != xType)
|
if (zType != xType)
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType);
|
throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType);
|
||||||
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
@ -612,8 +610,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong const* hZShapeInfo,
|
void *hZ, Nd4jLong const* hZShapeInfo,
|
||||||
void *dZ, Nd4jLong const* dZShapeInfo,
|
void *dZ, Nd4jLong const* dZShapeInfo,
|
||||||
int *dimension,int dimensionLength,
|
int *dimension,int dimensionLength) {
|
||||||
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
|
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
@ -627,11 +624,10 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
|
||||||
if (zType != sd::DataType::INT64)
|
if (zType != sd::DataType::INT64)
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execReduceLong wrong Z data type", sd::DataType::INT64, zType);
|
throw datatype_exception::build("NativeOpExecutioner::execReduceLong wrong Z data type", sd::DataType::INT64, zType);
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, LONG_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
@ -648,8 +644,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *hZ, Nd4jLong const* hZShapeInfo,
|
void *hZ, Nd4jLong const* hZShapeInfo,
|
||||||
void *dZ, Nd4jLong const* dZShapeInfo,
|
void *dZ, Nd4jLong const* dZShapeInfo,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength) {
|
||||||
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
|
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
@ -663,11 +658,10 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
if (zType != sd::DataType::BOOL)
|
if (zType != sd::DataType::BOOL)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execReduceBool requires Z operand to have BOOL type");
|
throw std::runtime_error("NativeOpExecutioner::execReduceBool requires Z operand to have BOOL type");
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
@ -675,6 +669,45 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
|
||||||
throw cuda_exception::build("execReduceBool failed", res);
|
throw cuda_exception::build("execReduceBool failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param opNum
|
||||||
|
* @param dX
|
||||||
|
* @param dXShapeInfo
|
||||||
|
* @param extraParams
|
||||||
|
* @param dZ
|
||||||
|
* @param dZShapeInfo
|
||||||
|
*/
|
||||||
|
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
||||||
|
int opNum,
|
||||||
|
const void *hX, const Nd4jLong *hXShapeInfo,
|
||||||
|
const void *dX, const Nd4jLong *dXShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
|
void *hZ, const Nd4jLong *hZShapeInfo,
|
||||||
|
void *dZ, const Nd4jLong *dZShapeInfo,
|
||||||
|
int *dimension, int dimensionLength) {
|
||||||
|
|
||||||
|
auto stream = lc->getCudaStream();
|
||||||
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
|
||||||
|
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||||
|
printf("F8 opNum:[%i]\n", opNum);
|
||||||
|
|
||||||
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceFloat failed", res);
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -707,7 +740,8 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
auto tadLength = shape::length(hXShapeInfo) / numBlocks;
|
||||||
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, tadLength < CUDA_BLOCK_SIZE ? tadLength : CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32)
|
if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32)
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType);
|
throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType);
|
||||||
|
@ -722,46 +756,6 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
|
||||||
throw cuda_exception::build("execIndexReduce failed", res);
|
throw cuda_exception::build("execIndexReduce failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param opNum
|
|
||||||
* @param dX
|
|
||||||
* @param dXShapeInfo
|
|
||||||
* @param extraParams
|
|
||||||
* @param dZ
|
|
||||||
* @param dZShapeInfo
|
|
||||||
*/
|
|
||||||
void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
|
|
||||||
int opNum,
|
|
||||||
void const* hX, Nd4jLong const* hXShapeInfo,
|
|
||||||
void const* dX, Nd4jLong const* dXShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *hZ, Nd4jLong const* hZShapeInfo,
|
|
||||||
void *dZ, Nd4jLong const* dZShapeInfo,
|
|
||||||
int *dimension,int dimensionLength,
|
|
||||||
Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
|
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
|
||||||
|
|
||||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
|
||||||
printf("F8 opNum:[%i]\n", opNum);
|
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
|
|
||||||
// TODO: remove after the release
|
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
|
||||||
if (res != 0)
|
|
||||||
throw cuda_exception::build("execReduceFloat failed", res);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -790,7 +784,7 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = 256;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1)
|
if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1)
|
||||||
printf("AF1 opNum:[%i]\n", opNum);
|
printf("AF1 opNum:[%i]\n", opNum);
|
||||||
|
@ -840,7 +834,7 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = 256;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
@ -870,9 +864,9 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
|
||||||
throw std::runtime_error("NativeOpExecutioner::execReduceBoolScalar requires Z operand to have BOOL type");
|
throw std::runtime_error("NativeOpExecutioner::execReduceBoolScalar requires Z operand to have BOOL type");
|
||||||
|
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = CUDA_BLOCK_SIZE;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
|
@ -901,9 +895,9 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execReduceSameScalar requires both X & Z operands to have same type", xType, zType);
|
throw datatype_exception::build("NativeOpExecutioner::execReduceSameScalar requires both X & Z operands to have same type", xType, zType);
|
||||||
|
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = CUDA_BLOCK_SIZE;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -932,9 +926,9 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execReduceLongScalar wrong Z data type", sd::DataType::INT64, zType);
|
throw datatype_exception::build("NativeOpExecutioner::execReduceLongScalar wrong Z data type", sd::DataType::INT64, zType);
|
||||||
|
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = CUDA_BLOCK_SIZE;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES);
|
||||||
|
|
||||||
|
@ -1128,7 +1122,7 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
|
||||||
dim3 launchDims = dim3(256, 256, 32768);
|
dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -1158,7 +1152,7 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
auto reductionPointer = lc->getReductionPointer();
|
auto reductionPointer = lc->getReductionPointer();
|
||||||
|
|
||||||
dim3 launchDims = dim3(256, 256, 32768);
|
dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
@ -1194,9 +1188,9 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
|
||||||
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
auto blockWidth = 256;
|
auto blockWidth = CUDA_BLOCK_SIZE;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
|
||||||
|
|
||||||
if (xType != yType)
|
if (xType != yType)
|
||||||
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType);
|
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType);
|
||||||
|
@ -1246,7 +1240,7 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
|
||||||
|
|
||||||
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum,
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum,
|
||||||
dX, dXShapeInfo,
|
dX, dXShapeInfo,
|
||||||
|
@ -1286,9 +1280,9 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
|
||||||
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
auto xLength = shape::length(hXShapeInfo);
|
auto xLength = shape::length(hXShapeInfo);
|
||||||
auto blockWidth = 256;
|
auto blockWidth = CUDA_BLOCK_SIZE;
|
||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
|
||||||
|
|
||||||
if (xType != yType)
|
if (xType != yType)
|
||||||
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType);
|
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType);
|
||||||
|
@ -1652,7 +1646,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
|
||||||
if (sd::Environment::getInstance().isDebugAndVerbose())
|
if (sd::Environment::getInstance().isDebugAndVerbose())
|
||||||
printf("D119 opNum:[%i]\n", opNum);
|
printf("D119 opNum:[%i]\n", opNum);
|
||||||
|
|
||||||
dim3 launchDims(shape::length(hZShapeInfo), 256, 32768);
|
dim3 launchDims(shape::length(hZShapeInfo), CUDA_BLOCK_SIZE / 2, 1024);
|
||||||
|
|
||||||
if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1)
|
if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1)
|
||||||
printf("AD119 opNum:[%i]\n", opNum);
|
printf("AD119 opNum:[%i]\n", opNum);
|
||||||
|
@ -1706,7 +1700,7 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
|
||||||
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType);
|
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType);
|
||||||
|
|
||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
|
@ -454,17 +454,24 @@ void execReduceSame2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
dimension,
|
|
||||||
shape::length(hDimensionShape));
|
|
||||||
|
|
||||||
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||||
NativeOpExecutioner::execReduceSame(&lc, opNum,
|
NativeOpExecutioner::execReduceSame(&lc, opNum,
|
||||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||||
extraParams,
|
extraParams,
|
||||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
|
||||||
dimension, dimensionLength,
|
dims.data(), dims.size());
|
||||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
|
||||||
|
|
||||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
|
@ -487,17 +494,25 @@ void execReduceLong2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
dimension,
|
|
||||||
shape::length(hDimensionShape));
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
|
||||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||||
NativeOpExecutioner::execReduceLong(&lc, opNum,
|
NativeOpExecutioner::execReduceLong(&lc, opNum,
|
||||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||||
extraParams,
|
extraParams,
|
||||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
|
||||||
dimension, dimensionLength,
|
dims.data(), dims.size());
|
||||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
|
||||||
|
|
||||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
|
@ -562,17 +577,25 @@ void execReduceBool2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
dimension,
|
|
||||||
shape::length(hDimensionShape));
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
|
||||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||||
NativeOpExecutioner::execReduceBool(&lc, opNum,
|
NativeOpExecutioner::execReduceBool(&lc, opNum,
|
||||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||||
extraParams,
|
extraParams,
|
||||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
|
||||||
dimension, dimensionLength,
|
dims.data(), dims.size());
|
||||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
|
||||||
|
|
||||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
|
@ -690,17 +713,25 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
|
||||||
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
auto dimension = reinterpret_cast<int *>(dbDimension->primary());
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
|
const auto zLen = shape::length(hZShapeInfo);
|
||||||
dimension,
|
|
||||||
shape::length(hDimensionShape));
|
std::vector<int> dimensions(dimension, dimension + dimensionLength);
|
||||||
|
|
||||||
|
const Nd4jLong* zShapeInfoH = hZShapeInfo;
|
||||||
|
|
||||||
|
if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
|
||||||
|
auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
|
||||||
|
zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector<int>();
|
||||||
|
|
||||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||||
NativeOpExecutioner::execReduceFloat(&lc, opNum,
|
NativeOpExecutioner::execReduceFloat(&lc, opNum,
|
||||||
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
|
||||||
extraParams,
|
extraParams,
|
||||||
dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
|
dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
|
||||||
dimension, dimensionLength,
|
dims.data(), dims.size());
|
||||||
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
|
||||||
|
|
||||||
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
|
|
|
@ -784,24 +784,14 @@ static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, cons
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
int coords[MAX_RANK];
|
||||||
|
Nd4jLong xOffset, yOffset, zOffset;
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
shape::getOffsetBroadcast(start, i, zShapeInfo, xShapeInfo, yShapeInfo, xzSameOffsets, yzSameOffsets, coords, zOffset, xOffset, yOffset);
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -665,24 +665,14 @@ static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, cons
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
int coords[MAX_RANK];
|
||||||
|
Nd4jLong xOffset, yOffset, zOffset;
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
shape::getOffsetBroadcast(start, i, zShapeInfo, xShapeInfo, yShapeInfo, xzSameOffsets, yzSameOffsets, coords, zOffset, xOffset, yOffset);
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
|
|
|
@ -651,24 +651,14 @@ static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, cons
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
int coords[MAX_RANK];
|
||||||
|
Nd4jLong xOffset, yOffset, zOffset;
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
shape::getOffsetBroadcast(start, i, zShapeInfo, xShapeInfo, yShapeInfo, xzSameOffsets, yzSameOffsets, coords, zOffset, xOffset, yOffset);
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,71 +114,6 @@ namespace functions {
|
||||||
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS);
|
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_BOOL_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void ReduceBoolFunction<X, Y>::exec(const int opNum,
|
|
||||||
const void *x, const Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *z, const Nd4jLong *zShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
|
|
||||||
int64_t start, int64_t stop) {
|
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset, start, stop), REDUCE_BOOL_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
void *vextraParams,
|
|
||||||
void *vresult, const Nd4jLong *zShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset, int64_t start, int64_t stop) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<const X *>(vx);
|
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
|
||||||
|
|
||||||
auto resultLength = shape::length(zShapeInfo);
|
|
||||||
|
|
||||||
if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) {
|
|
||||||
if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY)
|
|
||||||
return;
|
|
||||||
const auto startingVal = OpType::startingValue(x);
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < resultLength; i++)
|
|
||||||
z[i] = startingVal;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
|
||||||
//shape information for tad offset
|
|
||||||
//the squeezed information doesn't render the right strides for
|
|
||||||
//tad offset
|
|
||||||
// || tad.wholeThing
|
|
||||||
if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) {
|
|
||||||
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tadOnlyShapeInfo = tadShapeInfo;
|
|
||||||
auto tadOffsets = tadOffset;
|
|
||||||
|
|
||||||
if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) {
|
|
||||||
if (dimensionLength < 1)
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
|
||||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
|
||||||
tadOffsets = tadPack.primaryOffsets();
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
|
||||||
sd::ReductionLoops<X,Z,X>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
|
||||||
#else
|
|
||||||
sd::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
|
void _CUDA_H ReduceBoolFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
|
||||||
|
@ -220,7 +155,51 @@ namespace functions {
|
||||||
return OpType::postProcess(intermediate[0], length, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z>
|
||||||
|
template <typename OpType>
|
||||||
|
void _CUDA_H ReduceBoolFunction<X,Z>::exec(sd::memory::Workspace* workspace, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, const Nd4jLong *zShapeInfo, const int* dims) {
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES);
|
const X* x = reinterpret_cast<const X*>(vx);
|
||||||
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
X* extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
|
const int xRank = shape::rank(xShapeInfo);
|
||||||
|
const int zRank = shape::rank(zShapeInfo);
|
||||||
|
|
||||||
|
if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) {
|
||||||
|
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
const auto zLen = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < zLen; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shape::length(zShapeInfo) == 1) {
|
||||||
|
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef INLINE_LOOPS
|
||||||
|
sd::ReductionLoops<X,Z,X>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
#else
|
||||||
|
sd::ReductionBoolLoops<X,Z>::template innerloopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
void ReduceBoolFunction<X,Y>::exec(const int opNum, sd::memory::Workspace* workspace, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, const Nd4jLong *zShapeInfo, const int *dims) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(workspace, vx, xShapeInfo, vextraParams, vz, zShapeInfo, dims), REDUCE_BOOL_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT ReduceBoolFunction, , LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,11 +26,13 @@
|
||||||
#include <helpers/OmpLaunchHelper.h>
|
#include <helpers/OmpLaunchHelper.h>
|
||||||
#include <helpers/Loops.h>
|
#include <helpers/Loops.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <helpers/ShapeBuilders.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
namespace functions {
|
namespace functions {
|
||||||
namespace reduce {
|
namespace reduce {
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
|
void _CUDA_H ReduceFloatFunction<X,Z>::execScalar(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
@ -133,86 +135,6 @@ namespace functions {
|
||||||
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS);
|
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo), REDUCE_FLOAT_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void ReduceFloatFunction<X, Y>::exec(const int opNum,
|
|
||||||
const void *x, const Nd4jLong *xShapeInfo,
|
|
||||||
void *extraParams,
|
|
||||||
void *z, const Nd4jLong *zShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
|
|
||||||
int64_t start, int64_t stop) {
|
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x,
|
|
||||||
xShapeInfo,
|
|
||||||
extraParams,
|
|
||||||
z,
|
|
||||||
zShapeInfo,
|
|
||||||
dimension,
|
|
||||||
dimensionLength,
|
|
||||||
tadShapeInfo,
|
|
||||||
tadOffset, start, stop),
|
|
||||||
REDUCE_FLOAT_OPS);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
|
||||||
template <typename OpType>
|
|
||||||
void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|
||||||
void *vextraParams,
|
|
||||||
void *vresult, const Nd4jLong *zShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
|
||||||
const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffset,
|
|
||||||
int64_t start, int64_t stop) {
|
|
||||||
|
|
||||||
auto x = reinterpret_cast<const X *>(vx);
|
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
|
||||||
|
|
||||||
auto resultLength = shape::length(zShapeInfo);
|
|
||||||
|
|
||||||
if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) {
|
|
||||||
if(sd::ArrayOptions::arrayType(zShapeInfo) == sd::ArrayType::EMPTY)
|
|
||||||
return;
|
|
||||||
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? sd::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < resultLength; i++)
|
|
||||||
z[i] = startingVal;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
|
||||||
//shape information for tad offset
|
|
||||||
//the squeezed information doesn't render the right strides for
|
|
||||||
//tad offset
|
|
||||||
// || tad.wholeThing
|
|
||||||
if (resultLength == 1 || dimension == nullptr || dimensionLength == shape::rank(xShapeInfo)) {
|
|
||||||
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (OpType::requiresSpecialAccumulation) {
|
|
||||||
OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tadOnlyShapeInfo = tadShapeInfo;
|
|
||||||
auto tadOffsets = tadOffset;
|
|
||||||
|
|
||||||
if (tadOnlyShapeInfo == nullptr || tadOffsets == nullptr) {
|
|
||||||
if (dimensionLength < 0)
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
|
||||||
tadOnlyShapeInfo = tadPack.primaryShapeInfo();
|
|
||||||
tadOffsets = tadPack.primaryOffsets();
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef INLINE_LOOPS
|
|
||||||
sd::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
|
||||||
#else
|
|
||||||
sd::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams, start, stop);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
|
void _CUDA_H ReduceFloatFunction<X,Z>::exec(const void *x, const Nd4jLong *xShapeInfo,
|
||||||
|
@ -255,5 +177,54 @@ namespace functions {
|
||||||
// return result
|
// return result
|
||||||
return OpType::postProcess(intermediate[0], length, extraParams);
|
return OpType::postProcess(intermediate[0], length, extraParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z>
|
||||||
|
template<typename OpType>
|
||||||
|
void _CUDA_H ReduceFloatFunction<X, Z>::exec(sd::memory::Workspace* workspace, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, const Nd4jLong *zShapeInfo, const int* dims) {
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<const X*>(vx);
|
||||||
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
Z* extraParams = reinterpret_cast<Z*>(vextraParams);
|
||||||
|
|
||||||
|
const int xRank = shape::rank(xShapeInfo);
|
||||||
|
const int zRank = shape::rank(zShapeInfo);
|
||||||
|
|
||||||
|
if(sd::ArrayOptions::arrayType(xShapeInfo) == sd::ArrayType::EMPTY) {
|
||||||
|
|
||||||
|
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? sd::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
||||||
|
const auto zLen = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
for (Nd4jLong i = 0; i < zLen; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shape::length(zShapeInfo) == 1) {
|
||||||
|
z[0] = execScalar<OpType>(x, xShapeInfo, extraParams);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (OpType::requiresSpecialAccumulation) {
|
||||||
|
OpType::execSpecial(x, xShapeInfo, extraParams, z, zShapeInfo, const_cast<int*>(dims)+zRank, xRank-zRank, nullptr, nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef INLINE_LOOPS
|
||||||
|
sd::ReductionLoops<X,Z,Z>::template loopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
#else
|
||||||
|
sd::ReductionFloatLoops<X,Z>::template innerloopReduce<OpType>(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
void ReduceFloatFunction<X, Y>::exec(const int opNum, sd::memory::Workspace* workspace, const void *vx, const Nd4jLong *xShapeInfo, void *vextraParams, void *vz, const Nd4jLong *zShapeInfo, const int *dims) {
|
||||||
|
|
||||||
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(workspace, vx, xShapeInfo, vextraParams, vz, zShapeInfo, dims), REDUCE_FLOAT_OPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue