diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml
index 76251f4cd..7c65b2efb 100644
--- a/arbiter/arbiter-core/pom.xml
+++ b/arbiter/arbiter-core/pom.xml
@@ -99,7 +99,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml
index ec7e22d3c..92e3fb7aa 100644
--- a/arbiter/arbiter-deeplearning4j/pom.xml
+++ b/arbiter/arbiter-deeplearning4j/pom.xml
@@ -77,7 +77,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml
index c4306b967..26aa80b07 100644
--- a/arbiter/arbiter-server/pom.xml
+++ b/arbiter/arbiter-server/pom.xml
@@ -63,7 +63,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml
index 7392392db..88f39a310 100644
--- a/arbiter/arbiter-ui/pom.xml
+++ b/arbiter/arbiter-ui/pom.xml
@@ -37,7 +37,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/arbiter/pom.xml b/arbiter/pom.xml
index 93f877968..ab8d72365 100644
--- a/arbiter/pom.xml
+++ b/arbiter/pom.xml
@@ -151,7 +151,7 @@
${skipTestResourceEnforcement}
- test-nd4j-native,test-nd4j-cuda-10.2
+ test-nd4j-native,test-nd4j-cuda-11.0
false
@@ -333,11 +333,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${nd4j.version}
test
diff --git a/change-cuda-versions.sh b/change-cuda-versions.sh
index 7b354d68b..cbe75d830 100755
--- a/change-cuda-versions.sh
+++ b/change-cuda-versions.sh
@@ -20,7 +20,7 @@
set -e
-VALID_VERSIONS=( 9.2 10.0 10.1 10.2 )
+VALID_VERSIONS=( 9.2 10.0 10.1 10.2 11.0 )
usage() {
echo "Usage: $(basename $0) [-h|--help]
@@ -47,6 +47,10 @@ check_cuda_version() {
check_cuda_version "$VERSION"
case $VERSION in
+ 11.0)
+ VERSION2="8.0"
+ VERSION3="1.5.4-SNAPSHOT"
+ ;;
10.2)
VERSION2="7.6"
VERSION3="1.5.3"
diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml
index 3c3eec86e..0b01863f6 100644
--- a/datavec/datavec-api/pom.xml
+++ b/datavec/datavec-api/pom.xml
@@ -126,7 +126,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml
index 60409bc53..eb61221c8 100644
--- a/datavec/datavec-arrow/pom.xml
+++ b/datavec/datavec-arrow/pom.xml
@@ -62,7 +62,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml
index 3b9674cd9..0c67ae396 100644
--- a/datavec/datavec-data/datavec-data-audio/pom.xml
+++ b/datavec/datavec-data/datavec-data-audio/pom.xml
@@ -79,7 +79,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-data-codec/pom.xml b/datavec/datavec-data/datavec-data-codec/pom.xml
index 7ef25d5a3..58eda4820 100644
--- a/datavec/datavec-data/datavec-data-codec/pom.xml
+++ b/datavec/datavec-data/datavec-data-codec/pom.xml
@@ -66,7 +66,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml
index aef66381a..c88c89208 100644
--- a/datavec/datavec-data/datavec-data-image/pom.xml
+++ b/datavec/datavec-data/datavec-data-image/pom.xml
@@ -128,7 +128,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-data-nlp/pom.xml b/datavec/datavec-data/datavec-data-nlp/pom.xml
index fb30b93e7..dd5860f9a 100644
--- a/datavec/datavec-data/datavec-data-nlp/pom.xml
+++ b/datavec/datavec-data/datavec-data-nlp/pom.xml
@@ -81,7 +81,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-geo/pom.xml b/datavec/datavec-data/datavec-geo/pom.xml
index f88bc84d1..792265434 100644
--- a/datavec/datavec-data/datavec-geo/pom.xml
+++ b/datavec/datavec-data/datavec-geo/pom.xml
@@ -56,7 +56,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml
index fb7eee69c..4ab64d8b5 100644
--- a/datavec/datavec-data/datavec-hadoop/pom.xml
+++ b/datavec/datavec-data/datavec-hadoop/pom.xml
@@ -74,7 +74,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-data/pom.xml b/datavec/datavec-data/pom.xml
index e40d96149..2510924b8 100644
--- a/datavec/datavec-data/pom.xml
+++ b/datavec/datavec-data/pom.xml
@@ -54,7 +54,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml
index 49dc26db8..dc6d3d6b7 100644
--- a/datavec/datavec-excel/pom.xml
+++ b/datavec/datavec-excel/pom.xml
@@ -65,7 +65,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml
index 6ef9b0441..612fb05e7 100644
--- a/datavec/datavec-jdbc/pom.xml
+++ b/datavec/datavec-jdbc/pom.xml
@@ -72,7 +72,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml
index 3adc0e011..06fd13fe2 100644
--- a/datavec/datavec-local/pom.xml
+++ b/datavec/datavec-local/pom.xml
@@ -95,7 +95,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml
index 526b8238a..dae915909 100644
--- a/datavec/datavec-python/pom.xml
+++ b/datavec/datavec-python/pom.xml
@@ -78,7 +78,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
index 3b564b1b3..ff8bdb853 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
@@ -60,7 +60,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml
index bac20d42e..68a78450d 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml
@@ -59,7 +59,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
index 0c05f327b..331c58a8c 100644
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
+++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
@@ -178,7 +178,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-spark-inference-parent/pom.xml b/datavec/datavec-spark-inference-parent/pom.xml
index cc3a6b0c1..5e98a1aad 100644
--- a/datavec/datavec-spark-inference-parent/pom.xml
+++ b/datavec/datavec-spark-inference-parent/pom.xml
@@ -38,7 +38,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml
index 345b774c3..2c547499c 100644
--- a/datavec/datavec-spark/pom.xml
+++ b/datavec/datavec-spark/pom.xml
@@ -144,7 +144,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/datavec/pom.xml b/datavec/pom.xml
index 1c49960d6..8e403ea1e 100644
--- a/datavec/pom.xml
+++ b/datavec/pom.xml
@@ -108,7 +108,7 @@
${skipTestResourceEnforcement}
- test-nd4j-native,test-nd4j-cuda-10.2
+ test-nd4j-native,test-nd4j-cuda-11.0
false
@@ -361,11 +361,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${nd4j.version}
test
diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml
index 5a4ba921d..d19aa85c4 100644
--- a/deeplearning4j/deeplearning4j-common-tests/pom.xml
+++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml
@@ -62,11 +62,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-common/pom.xml b/deeplearning4j/deeplearning4j-common/pom.xml
index 1928eff91..7e2e0ebb6 100644
--- a/deeplearning4j/deeplearning4j-common/pom.xml
+++ b/deeplearning4j/deeplearning4j-common/pom.xml
@@ -40,7 +40,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml
index 90c88d4c3..f8dc0123b 100644
--- a/deeplearning4j/deeplearning4j-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-core/pom.xml
@@ -180,11 +180,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-cuda/pom.xml b/deeplearning4j/deeplearning4j-cuda/pom.xml
index 30373db3a..e76e905df 100644
--- a/deeplearning4j/deeplearning4j-cuda/pom.xml
+++ b/deeplearning4j/deeplearning4j-cuda/pom.xml
@@ -16,7 +16,7 @@
4.0.0
- deeplearning4j-cuda-10.2
+ deeplearning4j-cuda-11.0
deeplearning4j-cuda
org.deeplearning4j
@@ -26,9 +26,9 @@
- 10.2
- 7.6
- 1.5.3
+ 11.0
+ 8.0
+ 1.5.4-SNAPSHOT
@@ -112,7 +112,7 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.java
index a4dddb759..91e3c4829 100644
--- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.java
+++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.java
@@ -254,17 +254,34 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
}
} else {
+ /*
code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0, algo1);
+ */
+ val fa = new cudnnConvolutionBwdFilterAlgoPerf_t();
+ val counts = new int[1];
+ code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa);
+ algo1[0] = fa.algo();
+
checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
+
+ /*
code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc,
mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE
: CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
0, algo2);
+ */
+
+ val da = new cudnnConvolutionBwdDataAlgoPerf_t();
+ code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc,
+ cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da);
+
+ algo2[0] = da.algo();
checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation);
}
@@ -461,11 +478,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
}
} else {
- code = cudnnGetConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc,
+ /*
+ code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc,
cudnnContext.filterDesc, cudnnContext.convDesc,
cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE
- ? CUDNN_CONVOLUTION_FWD_NO_WORKSPACE : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ ? CUDNN_CONVOLUTION_FWD_NO_WORKSPACE : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0, algo);
+ */
+
+ val cdf = new cudnnConvolutionFwdAlgoPerf_t();
+ val count = new int[1];
+ code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf);
if(code != CUDNN_STATUS_SUCCESS){
//If CuDNN can't infer algorithm - try IMPLICIT_GEMM
@@ -477,6 +500,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti
fwdAlgo = FwdAlgo.IMPLICIT_GEMM;
algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
+
+ if (code == CUDNN_STATUS_SUCCESS) algo[0] = cdf.algo();
}
if(log.isTraceEnabled()){
diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java
index 8d5ece6ff..636071a28 100644
--- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java
+++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/cuda/lstm/ValidateCudnnLSTM.java
@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@@ -269,7 +270,7 @@ public class ValidateCudnnLSTM extends BaseDL4JTest {
assertTrue(f.get(l0) instanceof CudnnLSTMHelper);
assertTrue(f.get(l1) instanceof CudnnLSTMHelper);
- Random r = new Random(12345);
+ Random r = new Random(123456);
for (int x = 0; x < 1; x++) {
INDArray input = Nd4j.rand(new int[] {minibatch, inputSize, timeSeriesLength});
INDArray labels = Nd4j.zeros(minibatch, nOut, timeSeriesLength);
@@ -284,7 +285,6 @@ public class ValidateCudnnLSTM extends BaseDL4JTest {
mln2.fit(ds);
}
-
assertEquals(mln1.params(), mln2.params());
}
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
index b1adbe93e..5a48cfc2d 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/pom.xml
@@ -51,7 +51,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
index c87a94b37..2ce302370 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/pom.xml
@@ -46,7 +46,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
index 7806bab88..103a1a8e7 100644
--- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml
@@ -43,7 +43,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-data/pom.xml b/deeplearning4j/deeplearning4j-data/pom.xml
index ca29f35b7..bcf4d3355 100644
--- a/deeplearning4j/deeplearning4j-data/pom.xml
+++ b/deeplearning4j/deeplearning4j-data/pom.xml
@@ -38,7 +38,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
index ed0160ccb..15c3c0715 100644
--- a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
+++ b/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml
@@ -116,11 +116,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml
index 645b4eca2..876470377 100644
--- a/deeplearning4j/deeplearning4j-graph/pom.xml
+++ b/deeplearning4j/deeplearning4j-graph/pom.xml
@@ -64,7 +64,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml
index 7ebb82e75..86c3e23d6 100644
--- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml
+++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml
@@ -69,7 +69,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-manifold/pom.xml b/deeplearning4j/deeplearning4j-manifold/pom.xml
index 921ee9653..6e1a8cd42 100644
--- a/deeplearning4j/deeplearning4j-manifold/pom.xml
+++ b/deeplearning4j/deeplearning4j-manifold/pom.xml
@@ -41,7 +41,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
index 383eb1c8c..13b7ee45d 100644
--- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml
@@ -302,11 +302,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml
index 3dcdcc720..4bbb32806 100644
--- a/deeplearning4j/deeplearning4j-modelimport/pom.xml
+++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml
@@ -135,11 +135,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
index 911432cf0..9820c29f2 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
@@ -125,11 +125,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
index 0886e8d5b..eed007f32 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
@@ -49,7 +49,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
index bfd004c41..902a67ae7 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
@@ -53,7 +53,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
index fbe0ddccf..6987dc556 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
@@ -89,11 +89,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
index 23d5d225d..70778f67d 100644
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml
@@ -44,7 +44,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml
index 23d863cdc..3b0fd8944 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/pom.xml
@@ -72,7 +72,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml
index a4fea6b07..c85e18cdd 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/pom.xml
@@ -75,7 +75,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
index c0fcdb84a..be02f45b6 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/pom.xml
@@ -67,7 +67,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml
index e5ee63ea0..7ec64d395 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml
@@ -110,7 +110,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
index 668c728ae..6595a3a1e 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
@@ -91,7 +91,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
index 6c4eea3fd..61627d8a9 100644
--- a/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-nlp-parent/pom.xml
@@ -42,7 +42,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-nn/pom.xml b/deeplearning4j/deeplearning4j-nn/pom.xml
index b817c0dc6..268a70cd9 100644
--- a/deeplearning4j/deeplearning4j-nn/pom.xml
+++ b/deeplearning4j/deeplearning4j-nn/pom.xml
@@ -128,7 +128,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml
index 338d0173f..ca4c49e9e 100644
--- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml
@@ -107,14 +107,14 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
false
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/deeplearning4j/deeplearning4j-remote/pom.xml
index 4329a1554..1816689a4 100644
--- a/deeplearning4j/deeplearning4j-remote/pom.xml
+++ b/deeplearning4j/deeplearning4j-remote/pom.xml
@@ -27,7 +27,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
index 2c7a94de8..18392dfc0 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/pom.xml
@@ -109,11 +109,11 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${project.version}
test
diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
index 3c083d40d..36a77391e 100644
--- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/pom.xml
@@ -104,7 +104,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/deeplearning4j/deeplearning4j-scaleout/pom.xml
index 539aa3ef7..d0a199d76 100644
--- a/deeplearning4j/deeplearning4j-scaleout/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/pom.xml
@@ -38,7 +38,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
index 3eafbb9e2..693fbcf7a 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml
@@ -80,7 +80,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
index c4e8dc7ab..ffd7a4b0f 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml
@@ -79,7 +79,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
index 1198ae733..53297cb13 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml
@@ -86,7 +86,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
index b7b20d161..9b399fa22 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml
@@ -105,7 +105,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
index 8a4fb02d5..8bafabc38 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml
@@ -181,7 +181,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
index 8f83b803e..09f5bb084 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-components/pom.xml
@@ -77,7 +77,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
index 3c755eac8..6e9cdad17 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/pom.xml
@@ -113,7 +113,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
index 1b85a1d87..c807b7a46 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-standalone/pom.xml
@@ -32,7 +32,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
index 44b868ae6..f24bf9109 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
@@ -67,7 +67,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
index a66b85ece..835e77fe0 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/pom.xml
@@ -457,7 +457,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-ui-parent/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/pom.xml
index 70c32b984..947a28783 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/pom.xml
@@ -49,7 +49,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/deeplearning4j-zoo/pom.xml b/deeplearning4j/deeplearning4j-zoo/pom.xml
index bec71ec04..ea431ebd9 100644
--- a/deeplearning4j/deeplearning4j-zoo/pom.xml
+++ b/deeplearning4j/deeplearning4j-zoo/pom.xml
@@ -85,7 +85,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml
index 43e6bfa60..e8d958cf1 100644
--- a/deeplearning4j/dl4j-integration-tests/pom.xml
+++ b/deeplearning4j/dl4j-integration-tests/pom.xml
@@ -120,7 +120,7 @@
test-nd4j-native
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
\ No newline at end of file
diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml
index 17c89f931..e121e305d 100644
--- a/deeplearning4j/pom.xml
+++ b/deeplearning4j/pom.xml
@@ -225,7 +225,7 @@
${skipBackendChoice}
- test-nd4j-native,test-nd4j-cuda-10.2
+ test-nd4j-native,test-nd4j-cuda-11.0
false
@@ -500,7 +500,7 @@
- test-nd4j-cuda-10.2
+ test-nd4j-cuda-11.0
false
@@ -513,7 +513,7 @@
org.nd4j
- nd4j-cuda-10.2
+ nd4j-cuda-11.0
${nd4j.version}
test
diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h
index 7b32b7d49..67452dda2 100644
--- a/libnd4j/include/array/NDArray.h
+++ b/libnd4j/include/array/NDArray.h
@@ -628,17 +628,17 @@ namespace sd {
* keepDims - if true then put unities in place of reduced dimensions
*/
- NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
- NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+ NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
+ NDArray reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
- NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
- NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+ NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
+ NDArray reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
- NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
- NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+ NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
+ NDArray reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
- NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
- NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false) const;
+ NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims = false) const;
+ NDArray reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims = false) const;
/**
* method reduces array by excluding its shapes along dimensions present in given dimensions vector
@@ -647,10 +647,10 @@ namespace sd {
* keepDims - if true then put unities in place of reduced dimensions
* extras - extra parameters
*/
- void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
- void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
- void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
- void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool supportOldShapes = false, const bool checkTargetShape = true) const;
+ void reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
+ void reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
+ void reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
+ void reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims = false, const bool checkTargetShape = true) const;
/**
* return variance of array elements set
diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX
index eefe169cf..cfd910343 100644
--- a/libnd4j/include/array/NDArray.hXX
+++ b/libnd4j/include/array/NDArray.hXX
@@ -1353,80 +1353,80 @@ void* NDArray::bufferWithOffset(Nd4jLong offset) {
//////////////////////////////////////////////////////////////////////////
// eventually method reduces array by excluding its shapes along axes present in dimensions vector
-NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
+NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector<int>& dimensions, const bool keepDims) const {
std::vector<int> copy(dimensions);
- auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, false, getContext()->getWorkspace());
NDArray result(newShape, true, getContext());
- this->reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
+ this->reduceAlongDimension(op, result, copy, keepDims, false);
return result;
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
+NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector<int>& dimensions, const bool keepDims) const {
std::vector<int> copy(dimensions);
- auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, false, getContext()->getWorkspace());
NDArray result(newShape, true, getContext());
- reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
+ reduceAlongDimension(op, result, copy, keepDims, false);
return result;
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
+NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector<int>& dimensions, const bool keepDims) const {
std::vector<int> copy(dimensions);
- auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, false, getContext()->getWorkspace());
NDArray result(newShape, true, getContext());
- reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
+ reduceAlongDimension(op, result, copy, keepDims, false);
return result;
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
+NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector<int>& dimensions, const bool keepDims) const {
std::vector<int> copy(dimensions);
- auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, false, getContext()->getWorkspace());
NDArray result(newShape, true, getContext());
- reduceAlongDimension(op, result, copy, keepDims, supportOldShapes, false);
+ reduceAlongDimension(op, result, copy, keepDims, false);
return result;
}
//////////////////////////////////////////////////////////////////////////
// method reduces array by excluding its shapes along axes present in dimensions vector
-NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
- return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
+NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
+ return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
- return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
+NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
+ return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
- return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
+NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
+ return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
}
//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims, const bool supportOldShapes) const {
- return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims, supportOldShapes);
+NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list<int>& dimensions, const bool keepDims) const {
+ return reduceAlongDimension(op, std::vector<int>(dimensions), keepDims);
}
//////////////////////////////////////////////////////////////////////////
@@ -4240,7 +4240,7 @@ NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, cons
//////////////////////////////////////////////////////////////////////////
// method reduces array by excluding its shapes along axes present in dimensions vector
-void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
+void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
if (isS())
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!");
@@ -4250,7 +4250,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
std::vector<int> copy(dimensions);
if(checkTargetShape) {
- auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
if(!shape::shapeEquals(newShape, target.shapeInfo()))
throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!");
}
@@ -4261,8 +4261,18 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
}
else {
- auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
- NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
+ const Nd4jLong* zShapeInfoH = target.shapeInfo();
+ const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
+
+ if(rankOf() - dimensions.size() != target.rankOf()) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
+ zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
+ zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
+ }
+
+ std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
+ NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
+
}
synchronize("NDArray::reduceAlongDimension FloatOps");
@@ -4271,7 +4281,7 @@ void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, con
//////////////////////////////////////////////////////////////////////////
// method reduces array by excluding its shapes along axes present in dimensions vector
-void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
+void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
if (isS())
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!");
@@ -4281,7 +4291,7 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
std::vector<int> copy(dimensions);
if(checkTargetShape) {
- auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
if(!shape::shapeEquals(newShape, target.shapeInfo()))
throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!");
}
@@ -4291,10 +4301,19 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
if(rankOf() == copy.size() || copy.empty()) {
NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
}
- else { //if (!isEmpty()) {
- auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
- auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
- NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
+ else {
+
+ const Nd4jLong* zShapeInfoH = target.shapeInfo();
+ const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
+
+ if(rankOf() - dimensions.size() != target.rankOf()) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
+ zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
+ zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
+ }
+
+ std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
+ NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
}
synchronize("NDArray::reduceAlongDimension SameOps");
@@ -4303,7 +4322,7 @@ void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, cons
//////////////////////////////////////////////////////////////////////////
// method reduces array by excluding its shapes along axes present in dimensions vector
-void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
+void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
if (isS())
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!");
@@ -4313,7 +4332,7 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
std::vector<int> copy(dimensions);
if(checkTargetShape) {
- auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
if(!shape::shapeEquals(newShape, target.shapeInfo()))
throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!");
}
@@ -4324,9 +4343,17 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
}
else {
- auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
- auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
- NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
+ const Nd4jLong* zShapeInfoH = target.shapeInfo();
+ const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
+
+ if(rankOf() - dimensions.size() != target.rankOf()) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
+ zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
+ zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
+ }
+
+ std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
+ NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
}
synchronize("NDArray::reduceAlongDimension LongOps");
@@ -4335,7 +4362,7 @@ void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, cons
//////////////////////////////////////////////////////////////////////////
// method reduces array by excluding its shapes along axes present in dimensions vector
-void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool supportOldShapes, const bool checkTargetShape) const {
+void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector<int>& dimensions, const bool keepDims, const bool checkTargetShape) const {
if (isS())
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!");
@@ -4345,7 +4372,7 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, cons
std::vector<int> copy(dimensions);
if(checkTargetShape) {
- auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, supportOldShapes, getContext()->getWorkspace());
+ auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace());
if(!shape::shapeEquals(newShape, target.shapeInfo()))
throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!");
}
@@ -4356,9 +4383,17 @@ void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, cons
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
}
else {
- auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
- auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), copy);
- NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
+ const Nd4jLong* zShapeInfoH = target.shapeInfo();
+ const Nd4jLong* zShapeInfoD = target.specialShapeInfo();
+
+ if(rankOf() - dimensions.size() != target.rankOf()) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace());
+ zShapeInfoH = reinterpret_cast<Nd4jLong const*>(zPack.primary());
+ zShapeInfoD = reinterpret_cast<Nd4jLong const*>(zPack.special());
+ }
+
+ std::vector<int> dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy);
+ NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size());
}
synchronize("NDArray::reduceAlongDimension LongOps");
diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h
index 25440e05c..65c3bcb99 100644
--- a/libnd4j/include/helpers/ConstantShapeHelper.h
+++ b/libnd4j/include/helpers/ConstantShapeHelper.h
@@ -46,11 +46,13 @@ namespace sd {
static ConstantShapeHelper & getInstance();
- ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
- ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
- ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
- ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
- ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
+ ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, const std::vector<Nd4jLong> &shape);
+ ConstantShapeBuffer& bufferForShapeInfo(const ShapeDescriptor &descriptor);
+ ConstantShapeBuffer& bufferForShapeInfo(const Nd4jLong *shapeInfo);
+ ConstantShapeBuffer& bufferForShapeInfo(sd::DataType dataType, char order, int rank, const Nd4jLong* shape);
+ ConstantShapeBuffer& createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> &dimensions = {});
+ ConstantShapeBuffer& createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* maxShapeInfo, const std::vector<int> &dimsWithUnities, sd::memory::Workspace* workspace = nullptr);
+ ConstantShapeBuffer& createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace = nullptr);
const Nd4jLong* emptyShapeInfo(sd::DataType dataType);
diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h
index 9bf3daede..325fa3505 100644
--- a/libnd4j/include/helpers/Loops.h
+++ b/libnd4j/include/helpers/Loops.h
@@ -41,43 +41,43 @@ namespace sd {
public:
template <typename OpType>
- static FORCEINLINE void loopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, E* extraParams, int64_t start, int64_t stop);
+ static FORCEINLINE void loopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, E* extraParams);
};
template <typename X, typename Z>
class ReductionFloatLoops : public ReductionLoops<X, Z, Z> {
public:
- static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
+ static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, Z* extraParams);
template <typename OpType>
- static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop);
+ static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, Z* extraParams);
};
template <typename X, typename Z>
class ND4J_EXPORT ReductionBoolLoops : public ReductionLoops<X, Z, X> {
public:
- static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
template <typename OpType>
- static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
};
template <typename X, typename Z>
class ND4J_EXPORT ReductionLongLoops : public ReductionLoops<X, Z, X> {
public:
- static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
template <typename OpType>
- static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
};
template <typename X>
class ND4J_EXPORT ReductionSameLoops : public ReductionLoops<X, X, X> {
public:
- static void wrapper(int opNum, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void wrapper(int opNum, sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, X* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
template <typename OpType>
- static void innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop);
+ static void innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, X* z, const Nd4jLong *zShapeInfo, const int* dims, X* extraParams);
};
@@ -122,372 +122,613 @@ namespace sd {
static void innerloopReduce3All(const X* x, const Nd4jLong* xShapeInfo, const X* y, const Nd4jLong* yShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, Z* extraParams, int64_t start, int64_t stop);
};
+//////////////////////////////////////////////////////////////////////////
+template
+static void reduceExec21(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i0 = start; i0 < stop; ++i0) {
+
+ auto x0 = x + i0 * xStrd0;
+ auto z0 = z + i0 * zStrd0;
+
+ auto s = OpType::startingValue(x0);
+
+ if(xStrd1 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ s = OpType::update(s, OpType::op(x0[i1], extraParams), extraParams);
+ else
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ s = OpType::update(s, OpType::op(x0[i1 * xStrd1], extraParams), extraParams);
+
+ *z0 = OpType::postProcess(s, static_cast(xAxis1), extraParams);
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+static void reduceExec31(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+
+ const Nd4jLong tadLen = static_cast(xAxis1 * xAxis2);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i0 = start; i0 < stop; ++i0) {
+
+ auto x0 = x + i0 * xStrd0;
+ auto z0 = z + i0 * zStrd0;
+
+ auto s = OpType::startingValue(x0);
+
+ if(xStrd1 == 1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2], extraParams), extraParams);
+ else if(xStrd2 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2], extraParams), extraParams);
+ else
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2], extraParams), extraParams);
+
+ *z0 = OpType::postProcess(s, tadLen, extraParams);
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec32(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+ auto func = PRAGMA_THREADS_FOR_2D {
- /*
- //////////////////////////////////////////////////////////////////////////////
- template
- void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
- const Y* y, const Nd4jLong* yShapeInfo,
- Z* z, const Nd4jLong* zShapeInfo,
- Z* extraParams,
- std::function op) {
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
- const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo);
+ auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
+ auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
- const Nd4jLong* xShape = shape::shapeOf(xShapeInfo);
- const Nd4jLong* xStride = shape::stride(xShapeInfo);
- const Nd4jLong* yStride = shape::stride(yShapeInfo);
- const Nd4jLong* zStride = shape::stride(zShapeInfo);
+ auto s = OpType::startingValue(x1);
- const Nd4jLong len = shape::length(xShapeInfo);
+ if(xStrd2 == 1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x1[i2], extraParams), extraParams);
+ else
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x1[i2 * xStrd2], extraParams), extraParams);
- OmpLaunchHelper threadsInfo(len);
+ *z1 = OpType::postProcess(s, static_cast(xAxis2), extraParams);
+ }
+ }
+ };
- switch (kindOfLoop) {
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
+}
- case LoopKind::EWS1: {
- PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
- {
- const auto threadNum = omp_get_thread_num();
- const auto threadOffset = threadsInfo.getThreadOffset(threadNum);
- const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum));
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec41(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
- const auto xi = x + threadOffset;
- const auto yi = y + threadOffset;
- auto zi = z + threadOffset;
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
- PRAGMA_OMP_SIMD
- for (uint i = 0; i < lenPerThread; i++)
- zi[i] = op(xi[i], yi[i], extraParams);
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ const Nd4jLong tadLen = static_cast(xAxis1 * xAxis2 * xAxis3);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i0 = start; i0 < stop; ++i0) {
+
+ auto x0 = x + i0 * xStrd0;
+ auto z0 = z + i0 * zStrd0;
+
+ auto s = OpType::startingValue(x0);
+
+ if(xStrd1 == 1)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
+ else if(xStrd2 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2 + i3*xStrd3], extraParams), extraParams);
+ else if(xStrd3 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3], extraParams), extraParams);
+ else
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
+
+ *z0 = OpType::postProcess(s, tadLen, extraParams);
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec42(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ const Nd4jLong tadLen = static_cast(xAxis2 * xAxis3);
+
+ auto func = PRAGMA_THREADS_FOR_2D {
+
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
+
+ auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
+ auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
+
+ auto s = OpType::startingValue(x1);
+
+ if(xStrd2 == 1)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x1[i2 + i3*xStrd3], extraParams), extraParams);
+ else if(xStrd3 == 1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3], extraParams), extraParams);
+ else
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3], extraParams), extraParams);
+
+ *z1 = OpType::postProcess(s, tadLen, extraParams);
+ }
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec43(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
+ const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ auto func = PRAGMA_THREADS_FOR_3D {
+
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
+ for (auto i2 = start_z; i2 < stop_z; ++i2) {
+
+ auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
+ auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
+
+ auto s = OpType::startingValue(x2);
+
+ if(xStrd3 == 1)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x2[i3], extraParams), extraParams);
+ else
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x2[i3*xStrd3], extraParams), extraParams);
+
+ *z2 = OpType::postProcess(s, static_cast(xAxis3), extraParams);
}
}
- break;
+ }
+ };
- case LoopKind::EWSNONZERO: {
- const uint xEws = shape::elementWiseStride(xShapeInfo);
- const uint yEws = shape::elementWiseStride(yShapeInfo);
- const uint zEws = shape::elementWiseStride(zShapeInfo);
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
+}
- PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
- {
- const auto threadNum = omp_get_thread_num();
- const auto threadOffset = threadsInfo.getThreadOffset(threadNum);
- const auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum));
- const auto xi = x + threadOffset * xEws;
- const auto yi = y + threadOffset * yEws;
- auto zi = z + threadOffset * zEws;
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec51(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
- PRAGMA_OMP_SIMD
- for (uint i = 0; i < lenPerThread; i++)
- zi[i*zEws] = op(xi[i*xEws], yi[i*yEws], extraParams);
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, dims[0]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, dims[0]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
+ const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
+
+ const Nd4jLong tadLen = static_cast(xAxis1 * xAxis2 * xAxis3 * xAxis4);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i0 = start; i0 < stop; ++i0) {
+
+ auto x0 = x + i0 * xStrd0;
+ auto z0 = z + i0 * zStrd0;
+
+ auto s = OpType::startingValue(x0);
+
+ if(xStrd1 == 1)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ s = OpType::update(s, OpType::op(x0[i1 + i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd2 == 1)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd3 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd4 == 1)
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3 + i4], extraParams), extraParams);
+ else
+ for (uint i1 = 0; i1 < xAxis1; ++i1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x0[i1*xStrd1 + i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+
+ *z0 = OpType::postProcess(s, tadLen, extraParams);
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec52(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[1]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[0]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, dims[2]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, dims[2]);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
+ const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
+
+ const Nd4jLong tadLen = static_cast(xAxis2 * xAxis3 * xAxis4);
+
+ auto func = PRAGMA_THREADS_FOR_2D {
+
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
+
+ auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
+ auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
+
+ auto s = OpType::startingValue(x1);
+
+ if(xStrd2 == 1)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ s = OpType::update(s, OpType::op(x1[i2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd3 == 1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd4 == 1)
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3 + i4], extraParams), extraParams);
+ else
+ for (uint i2 = 0; i2 < xAxis2; ++i2)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x1[i2*xStrd2 + i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+
+ *z1 = OpType::postProcess(s, tadLen, extraParams);
+ }
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1);
+}
+
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec53(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[2]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
+
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, dims[1]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, dims[1]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
+
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[0]);
+ const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
+
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, dims[3]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, dims[3]);
+
+ const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
+ const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
+
+ const Nd4jLong tadLen = static_cast(xAxis3 * xAxis4);
+
+ auto func = PRAGMA_THREADS_FOR_3D {
+
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
+ for (auto i2 = start_z; i2 < stop_z; ++i2) {
+
+ auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
+ auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
+
+ auto s = OpType::startingValue(x2);
+
+ if(xStrd3 == 1)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ s = OpType::update(s, OpType::op(x2[i3 + i4*xStrd4], extraParams), extraParams);
+ else if(xStrd4 == 1)
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x2[i3*xStrd3 + i4], extraParams), extraParams);
+ else
+ for (uint i3 = 0; i3 < xAxis3; ++i3)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x2[i3*xStrd3 + i4*xStrd4], extraParams), extraParams);
+
+ *z2 = OpType::postProcess(s, tadLen, extraParams);
}
}
- break;
+ }
+ };
- case LoopKind::RANK1: {
- PRAGMA_OMP_PARALLEL_FOR
- for (uint i0 = 0; i0 < len; ++i0)
- z[i0 * zStride[0]] = op(x[i0 * xStride[0]], y[i0 * yStride[0]], extraParams);
- }
- break;
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
+}
- case LoopKind::RANK2: {
- PRAGMA_OMP_PARALLEL_FOR_SIMD
- for (uint i0 = 0; i0 < xShape[0]; ++i0)
- for (uint i1 = 0; i1 < xShape[1]; ++i1)
- z[i0 * zStride[0] + i1 * zStride[1]] = op(x[i0 * xStride[0] + i1 * xStride[1]], y[i0 * yStride[0] + i1 * yStride[1]], extraParams);
- }
- break;
+//////////////////////////////////////////////////////////////////////////
+template
+void reduceExec54(const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
- case LoopKind::RANK3: {
- PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
- for (uint i0 = 0; i0 < xShape[0]; ++i0)
- for (uint i1 = 0; i1 < xShape[1]; ++i1)
- for (uint i2 = 0; i2 < xShape[2]; ++i2)
- z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]], extraParams);
- }
- break;
+ const uint xAxis0 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[3]);
+ const Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[0] : dims[3]);
+ const Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
- case LoopKind::RANK4: {
- PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(3)
- for (uint i0 = 0; i0 < xShape[0]; ++i0)
- for (uint i1 = 0; i1 < xShape[1]; ++i1)
- for (uint i2 = 0; i2 < xShape[2]; ++i2)
- for (uint i3 = 0; i3 < xShape[3]; ++i3)
- z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]], extraParams);
- }
- break;
+ const uint xAxis1 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[2]);
+ const Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[1] : dims[2]);
+ const Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
- case LoopKind::RANK5: {
- PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(4)
- for (uint i0 = 0; i0 < xShape[0]; ++i0)
- for (uint i1 = 0; i1 < xShape[1]; ++i1)
- for (uint i2 = 0; i2 < xShape[2]; ++i2)
- for (uint i3 = 0; i3 < xShape[3]; ++i3)
- for (uint i4 = 0; i4 < xShape[4]; ++i4)
- z[i0*zStride[0]+i1*zStride[1]+i2*zStride[2]+i3*zStride[3]+i4*zStride[4]] = op(x[i0*xStride[0]+i1*xStride[1]+i2*xStride[2]+i3*xStride[3]+i4*xStride[4]], y[i0*yStride[0]+i1*yStride[1]+i2*yStride[2]+i3*yStride[3]+i4*yStride[4]], extraParams);
- }
- break;
+ const uint xAxis2 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[1]);
+ const Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[2] : dims[1]);
+ const Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
- default: {
- uint xShapeInfoCast[MAX_RANK];
- uint yShapeInfoCast[MAX_RANK];
- uint zShapeInfoCast[MAX_RANK];
+ const uint xAxis3 = shape::sizeAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[3] : dims[0]);
+ const Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? dims[3] : dims[0]);
+ const Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
- bool canCastX = DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
- bool canCastY = DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
- bool canCastZ = DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
+ const uint xAxis4 = shape::sizeAt(xShapeInfo, dims[4]);
+ const Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, dims[4]);
- PRAGMA_OMP_PARALLEL_THREADS(threadsInfo._numThreads)
- {
- auto threadNum = omp_get_thread_num();
- auto threadOffset = threadsInfo.getThreadOffset(threadNum);
- auto lenPerThread = static_cast(threadsInfo.getItersPerThread(threadNum));
- PRAGMA_OMP_SIMD
- for (uint i = 0; i < lenPerThread; i++) {
- auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
- auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
- auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
- z[zOffset] = op(x[xOffset], y[yOffset], extraParams);
+ auto func = PRAGMA_THREADS_FOR_3D {
+
+ for (auto i0 = start_x; i0 < stop_x; ++i0) {
+ for (auto i1 = start_y; i1 < stop_y; ++i1) {
+ for (auto i2 = start_z; i2 < stop_z; ++i2) {
+ for (auto i3 = 0; i3 < xAxis3; ++i3) {
+
+ auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
+ auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
+
+ auto s = OpType::startingValue(x3);
+
+ if(xStrd4 == 1)
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x3[i4], extraParams), extraParams);
+ else
+ for (uint i4 = 0; i4 < xAxis4; ++i4)
+ s = OpType::update(s, OpType::op(x3[i4*xStrd4], extraParams), extraParams);
+
+ *z3 = OpType::postProcess(s, static_cast(xAxis4), extraParams);
}
}
}
}
+ };
+
+ samediff::Threads::parallel_for(func, 0,xAxis0,1, 0,xAxis1,1, 0,xAxis2,1);
+}
+
+
+////////////////////////////////////////////////////////////////////////
+template
+void reduceDefault(sd::memory::Workspace* workspace, const X *x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int *dims, E* extraParams) {
+
+ const int zRank = shape::rank(zShapeInfo);
+ const int tadRank = shape::rank(xShapeInfo) - zRank;
+
+ Nd4jLong* outerXTadShapeInfo = sd::ShapeBuilders::createSubArrShapeInfo(xShapeInfo, dims, zRank);
+ Nd4jLong* innerXTadShapeInfo = sd::ShapeBuilders::createSubArrShapeInfo(xShapeInfo, dims+zRank, tadRank);
+
+ const bool sameOffsets1 = shape::haveSameShapeAndStrides(zShapeInfo, outerXTadShapeInfo);
+ const bool sameOffsets2 = shape::haveSameShapeAndStrides(zShapeInfo, innerXTadShapeInfo);
+
+ const Nd4jLong zLen = shape::length(zShapeInfo);
+ const Nd4jLong tadLen = shape::length(innerXTadShapeInfo);
+
+ Nd4jLong* zOffsets = nullptr;
+ ALLOCATE(zOffsets, workspace, zLen, Nd4jLong);
+ shape::calcOffsets(zShapeInfo, zOffsets);
+
+ Nd4jLong* outerXTadOffsets = zOffsets;
+ if(!sameOffsets1) {
+ ALLOCATE(outerXTadOffsets, workspace, zLen, Nd4jLong);
+ shape::calcOffsets(outerXTadShapeInfo, outerXTadOffsets);
}
- */
-
-
- //////////////////////////////////////////////////////////////////////////////
- template
- template
- void sd::ReductionLoops::loopReduce(const X* x, const Nd4jLong* xShapeInfo,
- Z* z, const Nd4jLong* zShapeInfo,
- const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
- E* extraParams,
- int64_t start, int64_t stop) {
-
- const LoopKind::Kind kindOfLoop = LoopKind::deduceKindOfLoopTadXZ(xShapeInfo, zShapeInfo, tadShapeInfo);
-
- const Nd4jLong zLen = shape::length(zShapeInfo);
- const Nd4jLong tadLen = shape::length(tadShapeInfo);
-
- const uint tadEws = shape::elementWiseStride(tadShapeInfo);
- const uint zEws = shape::elementWiseStride(zShapeInfo);
-
- const Nd4jLong* tadShape = shape::shapeOf(tadShapeInfo);
- const Nd4jLong* tadStride = shape::stride(tadShapeInfo);
-
- int numThreads = OmpLaunchHelper::tadThreads(tadLen, zLen);
-
- switch (kindOfLoop) {
-
- //*********************************************//
- // case LoopKind::SMALLARR2DX: {
- // shape::printShapeInfoLinear(xShapeInfo);
- // shape::printShapeInfoLinear(zShapeInfo);
- // const auto xLen = zLen * tadLen;
- // for (uint i = 0; i < xLen; ++i) {
- // const auto zOffset = shape::subArrayOffset(i, xShapeInfo, zShapeInfo, dimsToExclude, dimsLen);
- // const uint tadInd = (i / tadEws) % tadLen;
- // auto startVal = tadInd ? z[zOffset] : static_cast(OpType::startingValue(x));
- // z[zOffset] = OpType::update(startVal, OpType::op(x[i], extraParams), extraParams);
- // if(tadInd == tadLen - 1)
- // z[zOffset] = OpType::postProcess(z[zOffset], tadLen, extraParams);
- // printf("%u - %lld\n", i, zOffset);
- // }
- // }
- case LoopKind::SMALLARR2DX: {
- const auto uTadLen = static_cast(tadLen);
- const auto uZLenMinusOne = static_cast(zLen - 1);
- const auto xLen = static_cast(zLen * uTadLen);
- const auto sv = static_cast(OpType::startingValue(x));
-
- for (uint i = 0; i <= uZLenMinusOne; i++)
- z[i] = OpType::startingValue(x);
-
- uint zOffset = 0;
- for (uint i = 0; i < xLen; ++i) {
- z[zOffset] = OpType::update(z[zOffset], OpType::op(x[i], extraParams), extraParams);
- zOffset = zOffset == uZLenMinusOne ? 0 : zOffset + 1;
- }
-
- for (uint i = 0; i <= uZLenMinusOne; i++)
- z[i] = OpType::postProcess(z[i], tadLen, extraParams);
- }
- break;
-
- //*********************************************//
- case LoopKind::EWS1: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong j = 0; j < tadLen; j++)
- s = OpType::update(s, OpType::op(tad[j], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::EWSNONZERO: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong j = 0; j < tadLen; j++)
- s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams);
-
- z[i * zEws] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::RANK1: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong i0 = 0; i0 < tadLen; ++i0)
- s = OpType::update(s, OpType::op(tad[i0 * tadStride[0]], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::RANK2: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
- for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
- s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1]], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::RANK3: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
- for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
- for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
- s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2]], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::RANK4: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
- for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
- for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
- for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3)
- s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3]], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::RANK5: {
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong i0 = 0; i0 < tadShape[0]; ++i0)
- for (Nd4jLong i1 = 0; i1 < tadShape[1]; ++i1)
- for (Nd4jLong i2 = 0; i2 < tadShape[2]; ++i2)
- for (Nd4jLong i3 = 0; i3 < tadShape[3]; ++i3)
- for (Nd4jLong i4 = 0; i4 < tadShape[4]; ++i4)
- s = OpType::update(s, OpType::op(tad[i0 * tadStride[0] + i1 * tadStride[1] + i2 * tadStride[2] + i3 * tadStride[3] + i4 * tadStride[4]], extraParams), extraParams);
-
- z[i] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::X_EWSNONZERO: {
- uint castZShapeInfo[MAX_RANK];
- const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo);
-
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong j = 0; j < tadLen; j++)
- s = OpType::update(s, OpType::op(tad[j * tadEws], extraParams), extraParams);
-
- auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
- z[zOffset] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- case LoopKind::Z_EWSNONZERO: {
- uint castTadShapeInfo[MAX_RANK];
- const bool canCastTad = sd::DataTypeUtils::castShapeInfo(tadShapeInfo, castTadShapeInfo);
-
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong j = 0; j < tadLen; j++) {
- auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
- s = OpType::update(s, OpType::op(tad[tadOffset], extraParams), extraParams);
- }
-
- z[i * zEws] = OpType::postProcess(s, tadLen, extraParams);
- };
- }
- break;
-
- //*********************************************//
- default: {
- auto innertadOffsets = new Nd4jLong[tadLen];
- shape::calcOffsets(tadShapeInfo, innertadOffsets);
-
- uint castZShapeInfo[MAX_RANK];
- const bool canCastZ = sd::DataTypeUtils::castShapeInfo(zShapeInfo, castZShapeInfo);
-
- for (auto i = start; i < stop; i++) {
- auto tad = x + tadOffsets[i];
- auto s = OpType::startingValue(tad);
-
- for (Nd4jLong j = 0; j < tadLen; j++)
- s = OpType::update(s, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams);
-
- auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
- z[zOffset] = OpType::postProcess(s, tadLen, extraParams);
- };
-
- delete[] innertadOffsets;
- }
- }
+ Nd4jLong* innerXTadOffsets = zOffsets;
+ if(!sameOffsets2) {
+ ALLOCATE(innerXTadOffsets, workspace, tadLen, Nd4jLong);
+ shape::calcOffsets(innerXTadShapeInfo, innerXTadOffsets);
}
+ auto func = PRAGMA_THREADS_FOR{
+
+ for (auto i = start; i < stop; ++i) {
+
+ const auto tad = x + outerXTadOffsets[i];
+ auto s = OpType::startingValue(tad);
+
+ for (Nd4jLong j = 0; j < tadLen; j++)
+ s = OpType::update(s, OpType::op(tad[innerXTadOffsets[j]], extraParams), extraParams);
+
+ z[zOffsets[i]] = OpType::postProcess(s, tadLen, extraParams);
+ }
+ };
+
+ samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
+
+ RELEASE(outerXTadShapeInfo, workspace);
+ RELEASE(innerXTadShapeInfo, workspace);
+ RELEASE(zOffsets, workspace);
+ if(!sameOffsets1)
+ RELEASE(outerXTadOffsets, workspace);
+ if(!sameOffsets2)
+ RELEASE(innerXTadOffsets, workspace);
+}
+
+//////////////////////////////////////////////////////////////////////////////
+template
+template
+void sd::ReductionLoops::loopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong *xShapeInfo, Z* z, const Nd4jLong *zShapeInfo, const int* dims, E* extraParams) {
+
+ const int xRank = shape::rank(xShapeInfo);
+ const int zRank = shape::rank(zShapeInfo);
+
+ // shape::printShapeInfoLinear(xShapeInfo);
+ // shape::printShapeInfoLinear(zShapeInfo);
+ // shape::printIntArray(dims, shape::rank(xShapeInfo));
+
+ if(xRank == 2 && zRank == 1)
+ reduceExec21(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 3 && zRank == 1)
+ reduceExec31(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 3 && zRank == 2)
+ reduceExec32(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 4 && zRank == 1)
+ reduceExec41(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 4 && zRank == 2)
+ reduceExec42(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 4 && zRank == 3)
+ reduceExec43(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 5 && zRank == 1)
+ reduceExec51(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 5 && zRank == 2)
+ reduceExec52(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 5 && zRank == 3)
+ reduceExec53(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else if(xRank == 5 && zRank == 4)
+ reduceExec54(x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+ else
+ reduceDefault(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
+}
+
//////////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h
index e2c29a280..14726d5e6 100644
--- a/libnd4j/include/helpers/ShapeBuilders.h
+++ b/libnd4j/include/helpers/ShapeBuilders.h
@@ -52,11 +52,9 @@ namespace sd {
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
/**
- * allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
- * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
- * then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
+ * allocates memory for sub-array shapeInfo and copy shape and strides at axes(positions) stored in dims
*/
- static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
+ static Nd4jLong* createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const sd::DataType dataType, memory::Workspace* workspace = nullptr);
diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h
index cb2faa43d..bd30d9225 100644
--- a/libnd4j/include/helpers/ShapeUtils.h
+++ b/libnd4j/include/helpers/ShapeUtils.h
@@ -40,6 +40,12 @@ namespace sd {
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
static const Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, sd::memory::Workspace* workspace = nullptr);
+
+ // for example
+ // if rank = 3 and dimsToExclude = {0,2} then output = {1,0,2}, if rank = 3 and dimsToExclude = {2} then output = {0,1,2}
+ // if rank = 3 and dimsToExclude = {0} then output = {1,2,0}, if rank = 4 and dimsToExclude = {0,3} then output = {1,2,0,3}
+ static std::vector evalDimsForReduceOp(const int rank, const std::vector& dimsToExclude);
+
/**
* evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to tf
diff --git a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h
index d87c20d3c..c48030542 100644
--- a/libnd4j/include/helpers/benchmark/ReductionBenchmark.h
+++ b/libnd4j/include/helpers/benchmark/ReductionBenchmark.h
@@ -94,9 +94,9 @@ namespace sd {
auto tadOffsets = Environment::getInstance().isCPU() ? pack.primaryOffsets() : pack.specialOffsets();
if (_opType == 0)
- NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets);
+ NativeOpExecutioner::execReduceFloat(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size());
else
- NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size(), tadOnlyShapeInfo, tadOffsets);
+ NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), _opNum, _x->buffer(), _x->shapeInfo(), _x->specialBuffer(), _x->specialShapeInfo(), nullptr, _z->buffer(), _z->shapeInfo(), _z->specialBuffer(), _z->specialShapeInfo(), nullptr, _axis.size());
}
manager.synchronize();
diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp
index 528527f36..f12616688 100644
--- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp
+++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp
@@ -184,6 +184,43 @@ ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast
return bufferForShapeInfo(descriptor);
}
-} // namespace sd
+
+
+////////////////////////////////////////////////////////////////////////
+ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* inShapeInfo, const std::vector &dimsWithUnities, sd::memory::Workspace* workspace) {
+
+ Nd4jLong* newShapeInfo = nullptr;
+ ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()), Nd4jLong);
+
+ int temp;
+ if(dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
+ auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
+ shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
+ } else {
+ shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsWithUnities.data(), dimsWithUnities.size(), newShapeInfo);
+ }
+
+ ShapeDescriptor descriptor(newShapeInfo);
+
+ RELEASE(newShapeInfo, workspace);
+
+ return bufferForShapeInfo(descriptor);
+}
+
+////////////////////////////////////////////////////////////////////////
+ConstantShapeBuffer& ConstantShapeHelper::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace) {
+
+ Nd4jLong* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);
+
+ ShapeDescriptor descriptor(newShapeInfo);
+
+ RELEASE(newShapeInfo, workspace);
+
+ return bufferForShapeInfo(descriptor);
+}
+
+
+
+}
#endif
\ No newline at end of file
diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp
index 437eebe1d..91467758b 100644
--- a/libnd4j/include/helpers/cpu/MmulHelper.cpp
+++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp
@@ -443,17 +443,17 @@ static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC,
// calculate index of current batch
Nd4jLong batchInd;
if(cRank > 2)
- batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims);
+ batchInd = shape::coords2index(cShapeInfo, cBatchDims, cRank - 2, cCoords.data());
// evaluate A coordinates
if(aRank > 2)
- shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims);
+ shape::index2coords(batchInd, aShapeInfo, aBatchDims, aRank - 2, aCoords.data());
aCoords[aMaxis] = cCoords[cMaxis];
aCoords[aKaxis] = 0;
// evaluate B coordinates
if(bRank > 2)
- shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims);
+ shape::index2coords(batchInd, bShapeInfo, bBatchDims, bRank - 2, bCoords.data());
bCoords[bKaxis] = 0;
bCoords[bNaxis] = cCoords[cNaxis];
diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp
index e122717fc..04e6e4eb5 100644
--- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp
+++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_bool.cpp
@@ -26,20 +26,19 @@ namespace sd {
template
template
- void ReductionBoolLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
+ void ReductionBoolLoops::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
#ifndef INLINE_LOOPS
- ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
+ ReductionLoops::template loopReduce(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
#endif
}
template
- void ReductionBoolLoops::wrapper(const int opNum,
+ void ReductionBoolLoops::wrapper(const int opNum, sd::memory::Workspace* workspace,
const X *x, const Nd4jLong *xShapeInfo,
Y *z, const Nd4jLong *zShapeInfo,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
- X *extraParams, int64_t start, int64_t stop) {
+ const int *dims, X *extraParams) {
#ifndef INLINE_LOOPS
- DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_BOOL_OPS);
+ DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_BOOL_OPS);
#endif
}
diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.hpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.hpp
index c7ed544b2..2ab71b34a 100644
--- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.hpp
+++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_float.hpp
@@ -28,20 +28,19 @@ namespace sd {
template
template
- void ReductionFloatLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, Z* extraParams, int64_t start, int64_t stop) {
+ void ReductionFloatLoops::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z* z, const Nd4jLong* zShapeInfo, const int* dims, Z* extraParams) {
#ifndef INLINE_LOOPS
- ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
+ ReductionLoops::template loopReduce(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
#endif
}
template
- void ReductionFloatLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo,
+ void ReductionFloatLoops::wrapper(const int opNum, sd::memory::Workspace* workspace,
+ const X *x, const Nd4jLong *xShapeInfo,
Y *z, const Nd4jLong *zShapeInfo,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets,
- Y *extraParams,
- int64_t start, int64_t stop) {
+ const int *dims, Y *extraParams) {
#ifndef INLINE_LOOPS
- DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_FLOAT_OPS);
+ DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_FLOAT_OPS);
#endif
}
diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp
index be6cb28bd..820091f09 100644
--- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp
+++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_long.cpp
@@ -33,18 +33,19 @@ namespace sd {
template
template
- void ReductionLongLoops::innerloopReduce(const X * x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
+ void ReductionLongLoops::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, Z *z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
#ifndef INLINE_LOOPS
- ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
+ ReductionLoops::template loopReduce(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
#endif
}
template
- void ReductionLongLoops::wrapper(const int opNum, const X *x, const Nd4jLong *xShapeInfo, Y *z,
- const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
- const Nd4jLong *tadOffsets, X *extraParams, int64_t start, int64_t stop) {
+ void ReductionLongLoops::wrapper(const int opNum, sd::memory::Workspace* workspace,
+ const X *x, const Nd4jLong *xShapeInfo,
+ Y *z, const Nd4jLong *zShapeInfo,
+ const int *dims, X *extraParams) {
#ifndef INLINE_LOOPS
- DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_LONG_OPS);
+ DISPATCH_BY_OPNUM_TT(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_LONG_OPS);
#endif
}
diff --git a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp
index 53725de83..2544a3c03 100644
--- a/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp
+++ b/libnd4j/include/helpers/cpu/loops/ReductionLoops_same.cpp
@@ -26,23 +26,23 @@ namespace sd {
template
template
- void ReductionSameLoops::innerloopReduce(const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets, X* extraParams, int64_t start, int64_t stop) {
+ void ReductionSameLoops::innerloopReduce(sd::memory::Workspace* workspace, const X* x, const Nd4jLong* xShapeInfo, X* z, const Nd4jLong* zShapeInfo, const int* dims, X* extraParams) {
#ifndef INLINE_LOOPS
- ReductionLoops::template loopReduce(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop);
+ ReductionLoops::template loopReduce(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams);
#endif
}
template
- void ReductionSameLoops::wrapper(const int opNum, const X *vx, const Nd4jLong *xShapeInfo, X *vz,
- const Nd4jLong *zShapeInfo, const Nd4jLong *tadShapeInfo,
- const Nd4jLong *tadOffsets,
- X *vextraParams, int64_t start, int64_t stop) {
+ void ReductionSameLoops::wrapper(const int opNum, sd::memory::Workspace* workspace,
+ const X *vx, const Nd4jLong *xShapeInfo,
+ X *vz, const Nd4jLong *zShapeInfo,
+ const int *dims, X *vextraParams) {
#ifndef INLINE_LOOPS
auto x = reinterpret_cast(vx);
auto z = reinterpret_cast(vz);
auto extraParams = reinterpret_cast(vextraParams);
- DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams, start, stop), REDUCE_SAME_OPS);
+ DISPATCH_BY_OPNUM_T(innerloopReduce, PARAMS(workspace, x, xShapeInfo, z, zShapeInfo, dims, extraParams), REDUCE_SAME_OPS);
#endif
}
diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
index 35ba60ca9..fb093e7b7 100644
--- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
+++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
@@ -24,6 +24,7 @@
#include
#include
#include
+#include
#include
#include
@@ -187,4 +188,38 @@ ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcas
return bufferForShapeInfo(descriptor);
}
+////////////////////////////////////////////////////////////////////////
+ConstantShapeBuffer& ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const Nd4jLong* inShapeInfo, const std::vector &dimsWithUnities, sd::memory::Workspace* workspace) {
+
+ Nd4jLong* newShapeInfo = nullptr;
+ ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()), Nd4jLong);
+
+ int temp;
+ if(dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
+ auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
+ shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
+ } else {
+ shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsWithUnities.data(), dimsWithUnities.size(), newShapeInfo);
+ }
+
+ ShapeDescriptor descriptor(newShapeInfo);
+
+ RELEASE(newShapeInfo, workspace);
+
+ return bufferForShapeInfo(descriptor);
+}
+
+////////////////////////////////////////////////////////////////////////
+ConstantShapeBuffer& ConstantShapeHelper::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, sd::memory::Workspace* workspace) {
+
+ Nd4jLong* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);
+
+ ShapeDescriptor descriptor(newShapeInfo);
+
+ RELEASE(newShapeInfo, workspace);
+
+ return bufferForShapeInfo(descriptor);
+}
+
+
}
\ No newline at end of file
diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu
index d1122d794..36f48184a 100644
--- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu
+++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu
@@ -571,17 +571,17 @@ static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInf
// calculate index of current batch
Nd4jLong batchInd;
if(cBatchDims != nullptr)
- batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims);
+ batchInd = shape::coords2index(cShapeInfo, cBatchDims, cRank - 2, cCoords);
// evaluate A coordinates
if(aBatchDims != nullptr)
- shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims);
+ shape::index2coords(batchInd, aShapeInfo, aBatchDims, aRank - 2, aCoords);
aCoords[aMaxis] = cCoords[cMaxis];
aCoords[aKaxis] = 0;
// evaluate B coordinates
if(bBatchDims != nullptr)
- shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims);
+ shape::index2coords(batchInd, bShapeInfo, bBatchDims, bRank - 2, bCoords);
bCoords[bKaxis] = 0;
bCoords[bNaxis] = cCoords[cNaxis];
diff --git a/libnd4j/include/helpers/impl/ShapeBuilders.cpp b/libnd4j/include/helpers/impl/ShapeBuilders.cpp
index 7c0c7fed6..dbcf6dac0 100644
--- a/libnd4j/include/helpers/impl/ShapeBuilders.cpp
+++ b/libnd4j/include/helpers/impl/ShapeBuilders.cpp
@@ -140,14 +140,26 @@ namespace sd {
}
////////////////////////////////////////////////////////////////////////////////
-Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
+Nd4jLong* ShapeBuilders::createSubArrShapeInfo(const Nd4jLong* inShapeInfo, const int* dims, const int dimsSize, memory::Workspace* workspace) {
- Nd4jLong *outShapeInfo = nullptr;
- ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
+ Nd4jLong *subArrShapeInfo = nullptr;
+ ALLOCATE(subArrShapeInfo, workspace, shape::shapeInfoLength(dimsSize), Nd4jLong);
- shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
+ subArrShapeInfo[0] = dimsSize; // rank
+ sd::ArrayOptions::copyDataType(subArrShapeInfo, inShapeInfo); // type
+ subArrShapeInfo[2*dimsSize + 3] = shape::order(inShapeInfo); // order
- return outShapeInfo;
+ Nd4jLong* shape = shape::shapeOf(subArrShapeInfo);
+ Nd4jLong* strides = shape::stride(subArrShapeInfo);
+
+ for(int i = 0; i < dimsSize; ++i) {
+ shape[i] = shape::sizeAt(inShapeInfo, dims[i]);
+ strides[i] = shape::strideAt(inShapeInfo, dims[i]);
+ }
+
+ shape::checkStridesEwsAndOrder(subArrShapeInfo);
+
+ return subArrShapeInfo;
}
}
\ No newline at end of file
diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp
index 2c189cff1..998df7728 100644
--- a/libnd4j/include/helpers/impl/ShapeUtils.cpp
+++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp
@@ -1062,6 +1062,17 @@ bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector ShapeUtils::evalDimsForReduceOp(const int rank, const std::vector& dimsToExclude) {
+
+ std::vector output = ShapeUtils::evalDimsToExclude(rank, dimsToExclude);
+
+ for(uint j = 0; j < dimsToExclude.size(); ++j)
+ output.emplace_back(dimsToExclude[j]);
+
+ return output;
+}
+
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector& sameDims) {
diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h
index 719b086cb..ca6054482 100644
--- a/libnd4j/include/helpers/shape.h
+++ b/libnd4j/include/helpers/shape.h
@@ -901,6 +901,16 @@ namespace shape {
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *coords, Nd4jLong baseOffset = 0);
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset = 0);
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset = 0);
+ ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, const int* dims); // length of dims is equal to rank of shapeInfo
+
+ // all three arrays should have same rank
+ // all three arrays should have same dimensions or some of them are 1 (that is satisfy broadcasting principle), strides may be different
+ // shapeInfo1 - first array should have max length compared to rest of two arrays
+ ND4J_EXPORT _CUDA_HD void getOffsetBroadcast(const Nd4jLong& startInd, const Nd4jLong ind,
+ const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2, const Nd4jLong* shapeInfo3,
+ const bool sameOffsets12, const bool sameOffsets13,
+ int* coords,
+ Nd4jLong& offset1, Nd4jLong& offset2, Nd4jLong& offset3);
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
@@ -918,11 +928,12 @@ namespace shape {
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, int *coords);
+ // ND4J_EXPORT _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, const int* dims, Nd4jLong *coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
- ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims);
+ ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords);
/**
* Convert coordinates to the corresponding linear index (sequence number in other words)
@@ -935,7 +946,7 @@ namespace shape {
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
- ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims);
+ ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int* dims, const int dimsSize, const int *coords);
/**
* increment n-dimensional array by one iteration by changing coord appropriately
@@ -951,7 +962,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned);
- ND4J_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo);
+ ND4J_EXPORT _CUDA_HD void printShapeInfo(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD void printShapeInfoLinear(const Nd4jLong *shapeInfo);
@@ -1057,10 +1068,10 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities);
/**
- * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3}, dimsSize = 2
+ * for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude(points on unity dimensions) = {1,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
- INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);
+ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int* dimsToExclude, const int dimsSize, Nd4jLong* outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
@@ -1847,13 +1858,13 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape,
return index;
}
-INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int *coords, const int dimsSize, const int* tadDims) {
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, const int *coords) {
Nd4jLong index, shift = 1;;
- index = coords[tadDims[dimsSize - 1]];
- for(uint i = dimsSize - 1; i >= 1; --i) {
- shift *= shapeInfo[tadDims[i]];
+ index = coords[dims[dimsLen - 1]];
+ for(uint i = dimsLen - 1; i >= 1; --i) {
+ shift *= shapeInfo[dims[i]];
index += shift * coords[i - 1];
}
@@ -3324,6 +3335,18 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong
return offset;
}
+//////////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
+
+ Nd4jLong offset = baseOffset;
+
+ for(uint i = 1; i <= shapeInfo[0]; ++i)
+ if(shapeInfo[i] != 1)
+ offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
+
+ return offset;
+}
+
//////////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, Nd4jLong baseOffset) {
@@ -3337,17 +3360,78 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coor
}
//////////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coords, Nd4jLong baseOffset) {
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const int *coords, const int* dims) {
- Nd4jLong offset = baseOffset;
+ Nd4jLong offset = 0;
for(uint i = 1; i <= shapeInfo[0]; ++i)
if(shapeInfo[i] != 1)
- offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
+ offset += coords[dims[i - 1]] * shapeInfo[shapeInfo[0] + i];
return offset;
}
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD void getOffsetBroadcast(const Nd4jLong& startInd, const Nd4jLong ind,
+ const Nd4jLong* shapeInfo1, const Nd4jLong* shapeInfo2, const Nd4jLong* shapeInfo3,
+ const bool sameOffsets12, const bool sameOffsets13,
+ int* coords,
+ Nd4jLong& offset1, Nd4jLong& offset2, Nd4jLong& offset3) {
+
+ const Nd4jLong* shape1 = shape::shapeOf(shapeInfo1);
+ const Nd4jLong* strides1 = shape::stride(shapeInfo1);
+ const Nd4jLong* shape2 = shape::shapeOf(shapeInfo2);
+ const Nd4jLong* strides2 = shape::stride(shapeInfo2);
+ const Nd4jLong* shape3 = shape::shapeOf(shapeInfo3);
+ const Nd4jLong* strides3 = shape::stride(shapeInfo3);
+
+ if(startInd == ind) {
+
+ if(shape::rank(shapeInfo1) == 0) {
+ offset1 = offset2 = offset3 = 0;
+ return;
+ }
+
+ shape::index2coords(ind, shapeInfo1, coords);
+ offset1 = shape::getOffset(shapeInfo1, coords);
+
+ if(sameOffsets12)
+ offset2 = offset1;
+ else
+ offset2 = shape::getOffset(shapeInfo2, coords);
+
+ if(sameOffsets13)
+ offset3 = offset1;
+ else
+ offset3 = shape::getOffset(shapeInfo3, coords);
+
+ return;
+ }
+
+ int axis = shapeInfo1[0] - 1;
+ while(coords[axis] == shape1[axis] - 1) {
+ if(!sameOffsets12 && shape2[axis] != 1)
+ offset2 -= (shape2[axis] - 1) * strides2[axis];
+ if(!sameOffsets13 && shape3[axis] != 1)
+ offset3 -= (shape3[axis] - 1) * strides3[axis];
+ if(shape1[axis] != 1)
+ offset1 -= (shape1[axis] - 1) * strides1[axis];
+ coords[axis--] = 0;
+ }
+
+ ++coords[axis];
+ offset1 += strides1[axis];
+
+ if(!sameOffsets12 && shape2[axis] != 1)
+ offset2 += strides2[axis];
+ if(!sameOffsets13 && shape3[axis] != 1)
+ offset3 += strides3[axis];
+
+ if(sameOffsets12)
+ offset2 = offset1;
+ if(sameOffsets13)
+ offset3 = offset1;
+}
/**
* Returns the tensor along dimension
@@ -3443,7 +3527,7 @@ INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const uint *coo
printf("\n");
}
- INLINEDEF _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo) {
+ INLINEDEF _CUDA_HD void printShapeInfo(const Nd4jLong *shapeInfo) {
int rank = shape::rank(shapeInfo);
Nd4jLong *shape = shape::shapeOf(shapeInfo);
printf("Rank %d\n",rank);
@@ -4583,89 +4667,92 @@ INLINEDEF void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const c
//////////////////////////////////////////////////////////////////////
INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order) {
- // if(false) { // tests showed that this code did calculation notably slower even for big N
- // Nd4jLong indexes[MAX_RANK];
- // PRAGMA_OMP_PARALLEL_FOR_ARGS(private(indexes))
- // for (Nd4jLong i = 0; i < N; ++i) {
- // shape::index2coords(rank, shape, i, indexes);
- // subArrOffsets[i] = 0;
- // for (int j = 0; j < rank; ++j)
- // if(shape[j] != 1)
- // subArrOffsets[i] += indexes[j] * strides[j];
- // }
- // return;
- // }
+ const uint64_t len = shape::prodLong(shape, rank);
// set offset for first sub-array, it is equal to zero always
offsets[0] = 0;
- Nd4jLong * idx = new Nd4jLong[rank];
- Nd4jLong* offsetPerDim = new Nd4jLong[rank];
- memset(idx, 0, sizeof(Nd4jLong) * rank);
+ uint coords[MAX_RANK];
+ memset(coords, 0, sizeof(uint) * rank);
- PRAGMA_OMP_SIMD
- for (int k = 0; k < rank; ++k)
- offsetPerDim[k] = (shape[k] - 1) * strides[k];
-
- Nd4jLong init = 0, i = 1;
- // nested loops - calculation of sub-array offsets
if(order == 'c') {
- Nd4jLong rankMinusOne = rank - 1, j = rankMinusOne;
-
- while(j >= 0) {
-
- if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
-
- if(j == rankMinusOne) { // last dimension
- for(int l = 1; l < shape[j]; ++l) {
- offsets[i] = offsets[i - 1] + strides[j];
- i++;
- }
- --j;
- }
- else if(idx[j] < shape[j] - 1) {
- init += strides[j];
- offsets[i++] = init;
- ++idx[j];
- j = rankMinusOne;
- }
- else {
- init -= offsetPerDim[j];
- idx[j--] = 0;
+ for (uint64_t i = 1; i < len; ++i) {
+ int axis = rank - 1;
+ offsets[i] = 0;
+ while(coords[axis] == shape[axis] - 1) {
+ offsets[i] -= (shape[axis] - 1) * strides[axis];
+ coords[axis--] = 0;
}
+ ++coords[axis];
+ offsets[i] += offsets[i-1] + strides[axis];
}
- }
- else {
+ } else {
- Nd4jLong j = 0;
-
- while(j < rank) {
-
- if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
-
- if(j == 0) { // last dimension
- for(int l = 1; l < shape[j]; ++l) {
- offsets[i] = offsets[i - 1] + strides[j];
- i++;
- }
- ++j;
- }
- else if(idx[j] < shape[j] - 1) {
- init += strides[j];
- offsets[i++] = init;
- ++idx[j];
- j = 0;
- }
- else {
- init -= offsetPerDim[j];
- idx[j++] = 0;
+ for (uint64_t i = 1; i < len; ++i) {
+ int axis = 0;
+ offsets[i] = 0;
+ while(coords[axis] == shape[axis] - 1) {
+ offsets[i] -= (shape[axis] - 1) * strides[axis];
+ coords[axis++] = 0;
}
+ ++coords[axis];
+ offsets[i] += offsets[i-1] + strides[axis];
}
}
- delete []idx;
- delete []offsetPerDim;
+ // Nd4jLong init = 0, i = 1;
+ // // nested loops - calculation of sub-array offsets
+ // if(order == 'c') {
+
+ // int rankMinusOne = rank - 1, j = rankMinusOne;
+
+ // while(j >= 0) {
+
+ // if(shape[j] == 1) { --j; continue; } // ignore dimensions equal to unity
+
+ // if(j == rankMinusOne) { // last dimension
+ // for(uint l = 1; l < shape[j]; ++l)
+ // offsets[i++] = offsets[i - 1] + strides[j];
+ // --j;
+ // }
+ // else if(coords[j] < shape[j] - 1) {
+ // init += strides[j];
+ // offsets[i++] = init;
+ // ++coords[j];
+ // j = rankMinusOne;
+ // }
+ // else {
+ // init -= (shape[j] - 1) * strides[j];
+ // coords[j--] = 0;
+ // }
+ // }
+ // }
+ // else {
+
+ // int j = 0;
+
+ // while(j < rank) {
+
+ // if(shape[j] == 1) { ++j; continue; } // ignore dimensions equal to unity
+
+ // if(j == 0) { // last dimension
+ // for(uint l = 1; l < shape[j]; ++l)
+ // offsets[i++] = offsets[i - 1] + strides[j];
+ // ++j;
+ // }
+ // else if(coords[j] < shape[j] - 1) {
+ // init += strides[j];
+ // offsets[i++] = init;
+ // ++coords[j];
+ // j = 0;
+ // }
+ // else {
+ // init -= (shape[j] - 1) * strides[j];
+ // coords[j++] = 0;
+ // }
+ // }
+ // }
}
//////////////////////////////////////////////////////////////////////
@@ -4884,13 +4971,14 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
}
//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords, const int dimsSize, const int* tadDims) {
+INLINEDEF _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords) {
- for(uint i = dimsSize - 1; i > 0; --i) {
- coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]];
- index /= shapeInfo[1 + tadDims[i]];
+ for(uint i = dimsLen - 1; i > 0; --i) {
+ const auto ind = dims[i];
+ coords[ind] = index % shapeInfo[1 + ind];
+ index /= shapeInfo[1 + ind];
}
- coords[tadDims[0]] = index; // last iteration
+ coords[dims[0]] = index; // last iteration
}
//////////////////////////////////////////////////////////////////////
@@ -4921,6 +5009,64 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo
}
}
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
+
+ const int rank = shape::rank(inShapeInfo);
+ const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
+
+ if(numOfNonUnities == rank) { // no unities in shape, no copy procedure
+ shapeNoUnities = const_cast(inShapeInfo) + 1;
+ stridesNoUnities = const_cast(inShapeInfo) + 1 + rank;
+ return numOfNonUnities;
+ }
+
+ for(uint j = 0, i = 0; i < rank; ++i) {
+ if(shape::shapeOf(inShapeInfo)[i] != 1) {
+ shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
+ shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
+ }
+ }
+
+ stridesNoUnities = shapeNoUnities + numOfNonUnities;
+
+ return numOfNonUnities;
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int* dimsToExclude, const int dimsSize, Nd4jLong* outShapeInfo) {
+
+ outShapeInfo[0] = inShapeInfo[0] - dimsSize;
+
+ for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
+ if(j < dimsSize && i == dimsToExclude[j]) {
+ ++j;
+ continue;
+ }
+
+ shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i];
+ shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
+ }
+
+ sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type
+ *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
+ outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
+}
+
+//////////////////////////////////////////////////////////////////////
+// INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLong& index, const Nd4jLong *shapeInfo, const int* dims, const int dimsLen, int *coords) {
+
+// if(startIndex == index) {
+// shape::index2coords(index, shapeInfo, dims, dimsLen, coords);
+// }
+// else {
+// int i = dimsLen - 1;
+// while(coords[dims[i]] == shape::sizeAt(shapeInfo, dims[i]) - 1)
+// coords[dims[i--]] = 0;
+// ++coords[dims[i]];
+// }
+// }
+
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@@ -5111,50 +5257,6 @@ INLINEDEF _CUDA_HD void index2coordsCPU(const Nd4jLong& startIndex, const Nd4jLo
// }
// }
-//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
-
- const int rank = shape::rank(inShapeInfo);
- const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
-
- if(numOfNonUnities == rank) { // no unities in shape, no copy procedure
- shapeNoUnities = const_cast(inShapeInfo) + 1;
- stridesNoUnities = const_cast(inShapeInfo) + 1 + rank;
- return numOfNonUnities;
- }
-
- for(uint j = 0, i = 0; i < rank; ++i) {
- if(shape::shapeOf(inShapeInfo)[i] != 1) {
- shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
- shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
- }
- }
-
- stridesNoUnities = shapeNoUnities + numOfNonUnities;
-
- return numOfNonUnities;
-}
-
-//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) {
-
- outShapeInfo[0] = inShapeInfo[0] - dimsSize;
-
- for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
- if(j < dimsSize && i == dimsToExclude[j]) {
- ++j;
- continue;
- }
-
- shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i];
- shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
- }
-
- sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type
- *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
- outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
-}
-
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h
index 84ab886c4..20f0b6532 100644
--- a/libnd4j/include/legacy/NativeOpExecutioner.h
+++ b/libnd4j/include/legacy/NativeOpExecutioner.h
@@ -470,8 +470,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
+ int *dimension, int dimensionLength);
static void execReduceSame(sd::LaunchContext *lc,
int opNum,
@@ -480,8 +479,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
+ int *dimension, int dimensionLength);
static void execReduceBool(sd::LaunchContext *lc,
int opNum,
@@ -490,8 +488,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
+ int *dimension, int dimensionLength);
static void execReduceLong(sd::LaunchContext *lc,
int opNum,
@@ -500,8 +497,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets);
+ int *dimension, int dimensionLength);
/**
*
diff --git a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp
index 6b6c51a13..be8a0fbb3 100644
--- a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp
+++ b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp
@@ -585,8 +585,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
+ int *dimension, int dimensionLength) {
@@ -597,13 +596,7 @@ void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
if (shape::isEmpty(hZShapeInfo))
return;
- auto func = PRAGMA_THREADS_FOR {
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, FLOAT_TYPES);
- };
-
- const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
-
- samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, FLOAT_TYPES);
}
////////////////////////////////////////////////////////////////////////
@@ -614,24 +607,16 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
+ int *dimension, int dimensionLength) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
- auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
// nothing to do here if result is empty
if (shape::isEmpty(hZShapeInfo))
return;
- auto func = PRAGMA_THREADS_FOR {
- BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES);
- };
-
- const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
-
- samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
+ BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES);
}
////////////////////////////////////////////////////////////////////////
@@ -642,8 +627,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
+ int *dimension, int dimensionLength) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@@ -653,13 +637,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
if (shape::isEmpty(hZShapeInfo))
return;
- auto func = PRAGMA_THREADS_FOR {
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, BOOL_TYPES);
- };
-
- const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
-
- samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, BOOL_TYPES);
}
////////////////////////////////////////////////////////////////////////
@@ -670,8 +648,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
void *extraParams,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
- int *dimension, int dimensionLength,
- const Nd4jLong *tadShapeInfo, const Nd4jLong *tadOffsets) {
+ int *dimension, int dimensionLength) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
@@ -681,13 +658,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
if (shape::isEmpty(hZShapeInfo))
return;
- auto func = PRAGMA_THREADS_FOR {
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, start, stop), LIBND4J_TYPES, LONG_TYPES);
- };
-
- const sd::LoopKind::Kind kindOfLoop = sd::LoopKind::deduceKindOfLoopTadXZ(hXShapeInfo, hZShapeInfo, tadShapeInfo);
-
- samediff::Threads::parallel_tad(func, 0, shape::length(hZShapeInfo), 1, kindOfLoop == sd::LoopKind::Kind::SMALLARR2DX ? 1 : sd::Environment::getInstance().maxMasterThreads());
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::exec(opNum, lc ? lc->getWorkspace() : nullptr, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, dimension), LIBND4J_TYPES, LONG_TYPES);
}
////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp
index f9e3f669c..463adc17e 100644
--- a/libnd4j/include/legacy/cpu/NativeOps.cpp
+++ b/libnd4j/include/legacy/cpu/NativeOps.cpp
@@ -447,28 +447,26 @@ void execReduceFloat2(Nd4jPointer *extraPointers,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
try {
+
auto dimension = reinterpret_cast(dbDimension->primary());
auto dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
+ const auto zLen = shape::length(hZShapeInfo);
- auto hTADShapeInfo = tadPackX.primaryShapeInfo();
- auto hTADOffsets = tadPackX.primaryOffsets();
+ std::vector dimensions(dimension, dimension + dimensionLength);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+ const Nd4jLong* zShapeInfoD = dZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ zShapeInfoD = reinterpret_cast(zPack.special());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
+ NativeOpExecutioner::execReduceFloat(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
- NativeOpExecutioner::execReduceFloat(nullptr, opNum,
- dbX->primary(),
- hXShapeInfo,
- dbX->special(),
- dXShapeInfo,
- extraParams,
- dbZ->primary(),
- hZShapeInfo,
- dbZ->special(),
- dZShapeInfo,
- dimension,
- dimensionLength,
- hTADShapeInfo,
- hTADOffsets);
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@@ -481,30 +479,27 @@ void execReduceBool2(Nd4jPointer *extraPointers,
void *extraParams,
OpaqueDataBuffer *dbZ, const Nd4jLong *hZShapeInfo, const Nd4jLong *dZShapeInfo,
OpaqueDataBuffer *dbDimension, const Nd4jLong *hDimensionShape, const Nd4jLong *dDimensionShape) {
+
try {
auto dimension = reinterpret_cast(dbDimension->primary());
auto dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
- dimensionLength);
+ std::vector dimensions(dimension, dimension + dimensionLength);
- auto hTADShapeInfo = tadPack.primaryShapeInfo();
- auto hTADOffsets = tadPack.primaryOffsets();
+ const auto zLen = shape::length(hZShapeInfo);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+ const Nd4jLong* zShapeInfoD = dZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo)) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ zShapeInfoD = reinterpret_cast(zPack.special());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
+ NativeOpExecutioner::execReduceBool(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
- NativeOpExecutioner::execReduceBool(nullptr, opNum,
- dbX->primary(),
- hXShapeInfo,
- dbX->special(),
- dXShapeInfo,
- extraParams,
- dbZ->primary(),
- hZShapeInfo,
- dbZ->special(),
- dZShapeInfo,
- dimension,
- dimensionLength,
- hTADShapeInfo,
- hTADOffsets);
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@@ -521,26 +516,22 @@ void execReduceSame2(Nd4jPointer *extraPointers,
auto dimension = reinterpret_cast(dbDimension->primary());
int dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension,
- dimensionLength);
+ std::vector dimensions(dimension, dimension + dimensionLength);
- auto hTADShapeInfo = tadPack.primaryShapeInfo();
- auto hTADOffsets = tadPack.primaryOffsets();
+ const auto zLen = shape::length(hZShapeInfo);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+ const Nd4jLong* zShapeInfoD = dZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ zShapeInfoD = reinterpret_cast(zPack.special());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
+ NativeOpExecutioner::execReduceSame(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
- NativeOpExecutioner::execReduceSame(nullptr, opNum,
- dbX->primary(),
- hXShapeInfo,
- dbX->special(),
- dXShapeInfo,
- extraParams,
- dbZ->primary(),
- hZShapeInfo,
- dbZ->special(),
- dZShapeInfo,
- dimension,
- dimensionLength,
- hTADShapeInfo,
- hTADOffsets);
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
@@ -557,25 +548,22 @@ void execReduceLong2(Nd4jPointer *extraPointers,
auto dimension = reinterpret_cast(dbDimension->primary());
int dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo, dimension, dimensionLength);
+ std::vector dimensions(dimension, dimension + dimensionLength);
- auto hTADShapeInfo = tadPack.primaryShapeInfo();
- auto hTADOffsets = tadPack.primaryOffsets();
+ const auto zLen = shape::length(hZShapeInfo);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+ const Nd4jLong* zShapeInfoD = dZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ zShapeInfoD = reinterpret_cast(zPack.special());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
+ NativeOpExecutioner::execReduceLong(nullptr, opNum, dbX->primary(), hXShapeInfo, dbX->special(), dXShapeInfo, extraParams, dbZ->primary(), zShapeInfoH, dbZ->special(), zShapeInfoD, dims.data(), dims.size());
- NativeOpExecutioner::execReduceLong(nullptr, opNum,
- dbX->primary(),
- hXShapeInfo,
- dbX->special(),
- dXShapeInfo,
- extraParams,
- dbZ->primary(),
- hZShapeInfo,
- dbZ->special(),
- dZShapeInfo,
- dimension,
- dimensionLength,
- hTADShapeInfo,
- hTADOffsets);
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
diff --git a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu
index 14cbf306a..cb3c78238 100644
--- a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu
+++ b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu
@@ -210,7 +210,7 @@ void NativeOpExecutioner::execSummaryStatsScalar(sd::LaunchContext *lc,
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
- dim3 launchDims = dim3(256, 256, 32768);
+ dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@@ -577,8 +577,7 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
void *extraParams,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
- int *dimension, int dimensionLength,
- Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
+ int *dimension, int dimensionLength) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@@ -588,15 +587,14 @@ void NativeOpExecutioner::execReduceSame(sd::LaunchContext *lc,
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
- auto xRank = shape::rank(hXShapeInfo);
if (zType != xType)
throw datatype_exception::build("NativeOpExecutioner::execReduceSame requires both X & Z operands to have same type", xType, zType);
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 8192);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
- BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@@ -612,8 +610,7 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
void *extraParams,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
- int *dimension,int dimensionLength,
- Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
+ int *dimension,int dimensionLength) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@@ -627,11 +624,10 @@ void NativeOpExecutioner::execReduceLong(sd::LaunchContext *lc,
if (zType != sd::DataType::INT64)
throw datatype_exception::build("NativeOpExecutioner::execReduceLong wrong Z data type", sd::DataType::INT64, zType);
- auto xRank = shape::rank(hXShapeInfo);
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, LONG_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@@ -648,8 +644,7 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, Nd4jLong const* hZShapeInfo,
void *dZ, Nd4jLong const* dZShapeInfo,
- int *dimension, int dimensionLength,
- Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
+ int *dimension, int dimensionLength) {
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
@@ -663,11 +658,10 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
if (zType != sd::DataType::BOOL)
throw std::runtime_error("NativeOpExecutioner::execReduceBool requires Z operand to have BOOL type");
- auto xRank = shape::rank(hXShapeInfo);
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, BOOL_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@@ -675,6 +669,45 @@ void NativeOpExecutioner::execReduceBool(sd::LaunchContext *lc,
throw cuda_exception::build("execReduceBool failed", res);
}
+////////////////////////////////////////////////////////////////////////
+/**
+ *
+ * @param opNum
+ * @param dX
+ * @param dXShapeInfo
+ * @param extraParams
+ * @param dZ
+ * @param dZShapeInfo
+ */
+void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
+ int opNum,
+ const void *hX, const Nd4jLong *hXShapeInfo,
+ const void *dX, const Nd4jLong *dXShapeInfo,
+ void *extraParams,
+ void *hZ, const Nd4jLong *hZShapeInfo,
+ void *dZ, const Nd4jLong *dZShapeInfo,
+ int *dimension, int dimensionLength) {
+
+ auto stream = lc->getCudaStream();
+ auto reductionPointer = lc->getReductionPointer();
+
+ if (sd::Environment::getInstance().isDebugAndVerbose())
+ printf("F8 opNum:[%i]\n", opNum);
+
+ auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
+ auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
+
+ auto numBlocks = shape::length(hZShapeInfo);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+
+ BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, reductionPointer, dZ, dZShapeInfo, hZShapeInfo, dimension), LIBND4J_TYPES, FLOAT_TYPES);
+
+ // TODO: remove after the release
+ auto res = cudaStreamSynchronize(*stream);
+ if (res != 0)
+ throw cuda_exception::build("execReduceFloat failed", res);
+}
+
////////////////////////////////////////////////////////////////////////
/**
*
@@ -707,7 +740,8 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+ auto tadLength = shape::length(hXShapeInfo) / numBlocks;
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, tadLength < CUDA_BLOCK_SIZE ? tadLength : CUDA_BLOCK_SIZE, 1024);
if (zType != sd::DataType::INT64 && zType != sd::DataType::INT32)
throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType);
@@ -722,46 +756,6 @@ void NativeOpExecutioner::execIndexReduce(sd::LaunchContext *lc,
throw cuda_exception::build("execIndexReduce failed", res);
}
-////////////////////////////////////////////////////////////////////////
-/**
- *
- * @param opNum
- * @param dX
- * @param dXShapeInfo
- * @param extraParams
- * @param dZ
- * @param dZShapeInfo
- */
-void NativeOpExecutioner::execReduceFloat(sd::LaunchContext *lc,
- int opNum,
- void const* hX, Nd4jLong const* hXShapeInfo,
- void const* dX, Nd4jLong const* dXShapeInfo,
- void *extraParams,
- void *hZ, Nd4jLong const* hZShapeInfo,
- void *dZ, Nd4jLong const* dZShapeInfo,
- int *dimension,int dimensionLength,
- Nd4jLong const* tadShapeInfo, Nd4jLong const* tadOffsets) {
-
- auto stream = lc->getCudaStream();
- auto reductionPointer = lc->getReductionPointer();
-
- if (sd::Environment::getInstance().isDebugAndVerbose())
- printf("F8 opNum:[%i]\n", opNum);
-
- auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
- auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
-
- auto xRank = shape::rank(hXShapeInfo);
- auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
-
- BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
-
- // TODO: remove after the release
- auto res = cudaStreamSynchronize(*stream);
- if (res != 0)
- throw cuda_exception::build("execReduceFloat failed", res);
-}
/**
@@ -790,7 +784,7 @@ void NativeOpExecutioner::execIndexReduceScalar(sd::LaunchContext *lc,
auto xLength = shape::length(hXShapeInfo);
auto blockWidth = 256;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
if (sd::Environment::getInstance().isDebugAndVerbose() && launchDims.x == 1)
printf("AF1 opNum:[%i]\n", opNum);
@@ -840,7 +834,7 @@ void NativeOpExecutioner::execReduceFloatScalar(sd::LaunchContext *lc,
auto xLength = shape::length(hXShapeInfo);
auto blockWidth = 256;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
@@ -870,9 +864,9 @@ void NativeOpExecutioner::execReduceBoolScalar(sd::LaunchContext *lc,
throw std::runtime_error("NativeOpExecutioner::execReduceBoolScalar requires Z operand to have BOOL type");
auto xLength = shape::length(hXShapeInfo);
- auto blockWidth = 256;
+ auto blockWidth = CUDA_BLOCK_SIZE;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES);
@@ -901,9 +895,9 @@ void NativeOpExecutioner::execReduceSameScalar(sd::LaunchContext *lc,
throw datatype_exception::build("NativeOpExecutioner::execReduceSameScalar requires both X & Z operands to have same type", xType, zType);
auto xLength = shape::length(hXShapeInfo);
- auto blockWidth = 256;
+ auto blockWidth = CUDA_BLOCK_SIZE;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES);
@@ -932,9 +926,9 @@ void NativeOpExecutioner::execReduceLongScalar(sd::LaunchContext *lc,
throw datatype_exception::build("NativeOpExecutioner::execReduceLongScalar wrong Z data type", sd::DataType::INT64, zType);
auto xLength = shape::length(hXShapeInfo);
- auto blockWidth = 256;
+ auto blockWidth = CUDA_BLOCK_SIZE;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES);
@@ -1128,7 +1122,7 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
- dim3 launchDims = dim3(256, 256, 32768);
+ dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@@ -1158,7 +1152,7 @@ void NativeOpExecutioner::execSummaryStats(sd::LaunchContext *lc,
auto stream = lc->getCudaStream();
auto reductionPointer = lc->getReductionPointer();
- dim3 launchDims = dim3(256, 256, 32768);
+ dim3 launchDims = dim3(256, CUDA_BLOCK_SIZE, 1024);
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
@@ -1194,9 +1188,9 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
- auto blockWidth = 256;
+ auto blockWidth = CUDA_BLOCK_SIZE;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(shape::length(hXShapeInfo), blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
if (xType != yType)
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3 requires Y operand to have X type", xType, yType);
@@ -1246,7 +1240,7 @@ void NativeOpExecutioner::execReduce3(sd::LaunchContext *lc,
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum,
dX, dXShapeInfo,
@@ -1286,9 +1280,9 @@ void NativeOpExecutioner::execReduce3Scalar(sd::LaunchContext *lc,
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
auto xLength = shape::length(hXShapeInfo);
- auto blockWidth = 256;
+ auto blockWidth = CUDA_BLOCK_SIZE;
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, blockWidth, 1024);
if (xType != yType)
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Y operand to have X type", xType, yType);
@@ -1652,7 +1646,7 @@ void NativeOpExecutioner::execReduce3All(sd::LaunchContext *lc,
if (sd::Environment::getInstance().isDebugAndVerbose())
printf("D119 opNum:[%i]\n", opNum);
- dim3 launchDims(shape::length(hZShapeInfo), 256, 32768);
+ dim3 launchDims(shape::length(hZShapeInfo), CUDA_BLOCK_SIZE / 2, 1024);
if (sd::Environment::getInstance().isVerbose() && launchDims.x == 1)
printf("AD119 opNum:[%i]\n", opNum);
@@ -1706,7 +1700,7 @@ void NativeOpExecutioner::execReduce3TAD(sd::LaunchContext *lc,
throw sd::datatype_exception::build("NativeOpExecutioner::execReduce3TAD requires Z operand to have floating point data type", zType);
auto numBlocks = shape::length(hZShapeInfo);
- dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, 256, 32768);
+ dim3 launchDims(numBlocks == 0 ? 1 : numBlocks, CUDA_BLOCK_SIZE, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu
index 1ccc2c7d5..186b3a9cb 100755
--- a/libnd4j/include/legacy/cuda/NativeOps.cu
+++ b/libnd4j/include/legacy/cuda/NativeOps.cu
@@ -454,17 +454,24 @@ void execReduceSame2(Nd4jPointer *extraPointers,
auto dimension = reinterpret_cast(dbDimension->primary());
int dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
- dimension,
- shape::length(hDimensionShape));
+ const auto zLen = shape::length(hZShapeInfo);
+ std::vector dimensions(dimension, dimension + dimensionLength);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execReduceSame(&lc, opNum,
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
extraParams,
- dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
- dimension, dimensionLength,
- tadPack.specialShapeInfo(), tadPack.specialOffsets());
+ dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
+ dims.data(), dims.size());
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
} catch (std::exception &e) {
@@ -487,17 +494,25 @@ void execReduceLong2(Nd4jPointer *extraPointers,
auto dimension = reinterpret_cast(dbDimension->primary());
int dimensionLength = static_cast(shape::length(hDimensionShape));
- auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(hXShapeInfo,
- dimension,
- shape::length(hDimensionShape));
+ const auto zLen = shape::length(hZShapeInfo);
+
+ std::vector dimensions(dimension, dimension + dimensionLength);
+
+ const Nd4jLong* zShapeInfoH = hZShapeInfo;
+
+ if(shape::rank(hXShapeInfo) - dimensionLength != shape::rank(hZShapeInfo) && zLen != 1) {
+ auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(hZShapeInfo, dimensions);
+ zShapeInfoH = reinterpret_cast(zPack.primary());
+ }
+
+ std::vector dims = (zLen != 1) ? ShapeUtils::evalDimsForReduceOp(shape::rank(hXShapeInfo), dimensions) : std::vector();
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execReduceLong(&lc, opNum,
dbX->primary(), hXShapeInfo, dbX->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hXShapeInfo).special(),
extraParams,
- dbZ->primary(), hZShapeInfo, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(hZShapeInfo).special(),
- dimension, dimensionLength,
- tadPack.specialShapeInfo(), tadPack.specialOffsets());
+ dbZ->primary(), zShapeInfoH, dbZ->special(), ConstantShapeHelper::getInstance().bufferForShapeInfo(zShapeInfoH).special(),
+ dims.data(), dims.size());
InteropDataBuffer::registerSpecialUse({dbZ}, {dbX});
} catch (std::exception &e) {
@@ -562,17 +577,25 @@ void execReduceBool2(Nd4jPointer *extraPointers,
auto dimension = reinterpret_cast(dbDimension->primary());
int dimensionLength = static_cast