From 66b84b38cf294d5ebee60e6d113a9edf215a9745 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Tue, 19 Nov 2019 15:39:36 +0200 Subject: [PATCH] Shyrma mmul (#58) * - get rid of some copy procedures in mmulHelper ops Signed-off-by: Yurii * - further work on embedding cuda api for batched gemm (cublasGemmBatchedEx) in our mmulHelper class Signed-off-by: Yurii * - further work on cuda batched gemm api Signed-off-by: Yurii * - write our own cuda kernel performing batched gemm Signed-off-by: Yurii * missing include in MmulHelper Signed-off-by: raver119 * - forgot to keep the previous correct kernels for mmulNxN in the code, since it may happen that the new one will fail for some reason in the future Signed-off-by: Yurii * disable old tensordot Signed-off-by: raver119 * - rewrite cuda kernels for usualGemm and usualGemv Signed-off-by: Yurii * - profiling mmul helpers Signed-off-by: Yurii * - add prints to check shapes Signed-off-by: Yurii * - correct type of output array C in mmulNxN Signed-off-by: Yurii * - take into account possible nans in C array Signed-off-by: Yurii * slightly change numThreads message Signed-off-by: raver119 * - make corrections in accordance with the notes given in PR review Signed-off-by: Yurii --- libnd4j/blas/NDArray.h | 5 + libnd4j/blas/NDArray.hpp | 16 +- libnd4j/include/helpers/MmulHelper.h | 3 +- .../include/helpers/cpu/ConstantHelper.cpp | 1 + libnd4j/include/helpers/cpu/MmulHelper.cpp | 692 +++++++--- .../include/helpers/cuda_off/MmulHelper.cu | 1159 +++++++++++++---- libnd4j/include/helpers/impl/MmulHelper.cpp | 63 - libnd4j/include/helpers/shape.h | 41 +- .../declarable/helpers/cuda/convolutions.cu | 1 + .../ops/declarable/helpers/cuda/transforms.cu | 1 + .../tests_cpu/layers_tests/HelpersTests1.cpp | 2 + .../tests_cpu/layers_tests/NDArrayTests.cpp | 2 +- .../layers_tests/PerformanceTests.cpp | 6 +- .../layers_tests/PlaygroundTests.cpp | 25 + .../org/nd4j/nativeblas/NativeOpsHolder.java | 2 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 19 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 19 + 17 files changed, 1540 insertions(+), 517 deletions(-) diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index de2488f9d..cfad05b49 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1286,6 +1286,11 @@ namespace nd4j { */ Nd4jLong sizeAt(const int dim) const; + /** + * returns stride of "dim" dimension + */ + Nd4jLong strideAt(const int dim) const; + /** * returns order of array */ diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index c4a631cf5..a83472899 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -1439,9 +1439,21 @@ Nd4jLong NDArray::sizeAt(const int dim) const { throw std::runtime_error("Bad size index requested"); if (dim >= 0) - return this->_shapeInfo[1+dim]; + return shape::shapeOf(_shapeInfo)[dim]; else - return this->_shapeInfo[1+(this->rankOf() + dim)]; + return shape::shapeOf(_shapeInfo)[this->rankOf() + dim]; +} + +////////////////////////////////////////////////////////////////////////// +Nd4jLong NDArray::strideAt(const int dim) const { + + if (dim >= this->rankOf() || dim < -this->rankOf()) + throw std::runtime_error("NDArray::strideAt: Bad size index requested"); + + if (dim >= 0) + return shape::stride(_shapeInfo)[dim]; + else + return shape::stride(_shapeInfo)[this->rankOf() + dim]; } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/helpers/MmulHelper.h b/libnd4j/include/helpers/MmulHelper.h index 891525a73..ff0a7d1b2 100644 ---
a/libnd4j/include/helpers/MmulHelper.h +++ b/libnd4j/include/helpers/MmulHelper.h @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -53,7 +54,7 @@ namespace nd4j { #ifndef __JAVACPP_HACK__ /** - * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), user must take care of correctness of such arrays by himself + * modif - (can be empty) vector containing a subsequence of permutation/reshaping arrays (in any order), users must take care of the correctness of such arrays themselves */ static void tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC); static nd4j::NDArray* tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB); diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index b2549e93f..2ba2cc4e0 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index fca40d564..47fc3bd9f 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -21,6 +21,7 @@ #include "../MmulHelper.h" #include #include +#include #include #include @@ -28,110 +29,124 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////////// -// MxK x KxN = MxN +// MxK x KxN = MxN -> the actual sequence of axes doesn't matter template <typename T1, typename T2, typename T3> -static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static void usualGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { - T1* A = reinterpret_cast<T1*>(const_cast<void*>(vA)); - T2* B = reinterpret_cast<T2*>(const_cast<void*>(vB)); - T3* C = reinterpret_cast<T3*>(vC); - T3 alphaZ(alpha), betaZ(beta); - - const bool flagC = cOrder == 'f'; - const bool flagA = (flagC && transA) || (!flagC && !transA); - const bool flagB = (flagC && transB) || (!flagC && !transB); + const T1* A = vA->bufferAsT<T1>(); + const T2* B = vB->bufferAsT<T2>(); + T3* C = vC->bufferAsT<T3>(); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) - // for(uint row = 0; row < M; ++row) { + const T3 alphaZ = alpha; + const T3 betaZ = beta; - // T3* c = flagC ? (C + row) : (C + row * ldc); + const bool betaPersent = beta; - // for(uint col = 0; col < N; ++col) - // c[flagC ?
col * ldc : col] = 0; + const Nd4jLong* aShapeInfo = vA->getShapeInfo(); + const Nd4jLong* bShapeInfo = vB->getShapeInfo(); + const Nd4jLong* cShapeInfo = vC->getShapeInfo(); - // for(uint i = 0; i < K; ++i) { - - // T3* b = flagB ? (B + i * ldb) : (B + i); - // T3* a = flagA ? (A + row * lda + i) : (A + row + i * lda); + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); - // if(flagC) { - // PRAGMA_OMP_SIMD - // for(uint col = 0; col < N; ++col) { - // if(betaZ) - // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; - // else - // c[col * ldc] += a * b[flagB ? col : col * ldb]; - // } - // } - // else { - // PRAGMA_OMP_SIMD - // for(uint col = 0; col < N; ++col) { - // if(betaZ) - // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; - // else - // c[col] += a * b[flagB ? col : col * ldb]; - // } - // } - // } - // } + const Nd4jLong cLen = vC->lengthOf(); - auto func = PRAGMA_THREADS_FOR_2D { ; - for (auto row = start_x; row < stop_x; row += inc_x) { - for (auto col = start_y; col < stop_y; col += inc_y) { - T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col); - T3 val = 0; - - PRAGMA_OMP_SIMD - for (uint i = 0; i < K; ++i) { - T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda); - T3 b = flagB ? *(B + col + i * ldb) : *(B + col * ldb + i); - val += alphaZ * a * b; - } - - if (betaZ) - *c = val + betaZ * *c; - else - *c = val; - } - } - }; - - samediff::Threads::parallel_for(func, 0, M, 1, 0, N, 1); -} - -////////////////////////////////////////////////////////////////////////////// -// MXN x N = M -template -static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { - - T1* A = reinterpret_cast(const_cast(vA)); - T2* X = reinterpret_cast(const_cast(vX)); - T3* Y = reinterpret_cast(vY); - T3 alphaZ(alpha), betaZ(beta); - - const bool flagA = aOrder == 'f'; + const int K = vA->sizeAt(aKaxis); auto func = PRAGMA_THREADS_FOR { - for (auto row = start; row < stop; row += increment) { - T3 *y = Y + row * incy; - T3 val = 0; + std::vector aCoords(2), bCoords(2), cCoords(2); - PRAGMA_OMP_SIMD - for (int i = 0; i < N; ++i) { - T3 a = flagA ? 
*(A + row + i * lda) : *(A + row * lda + i); - T3 x = *(X + i * incx); - val += alphaZ * a * x; + for (auto i = start; i < stop; ++i) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords.data()); + + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; } - if (betaZ) - *y = val + betaZ * *y; + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if(betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; else - *y = val; + C[cOffset] = alphaZ * val; } }; - samediff::Threads::parallel_for(func, 0, M); + samediff::Threads::parallel_tad(func, 0, cLen); +} + + +////////////////////////////////////////////////////////////////////////////// +// MXN x N = M -> actual sequence of {M,N} axes doesn't matter +template +static void usualGemv(const NDArray* vA, const NDArray* vX, NDArray* vY, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + const T1* A = vA->bufferAsT(); + const T2* X = vX->bufferAsT(); + T3* Y = vY->bufferAsT(); + + const T3 alphaZ = alpha; + const T3 betaZ = beta; + + const bool betaPersent = beta; + + const Nd4jLong* aShapeInfo = vA->getShapeInfo(); + const Nd4jLong* xShapeInfo = vX->getShapeInfo(); + const Nd4jLong* yShapeInfo = vY->getShapeInfo(); + + const int N = vX->lengthOf(); + const int M = vY->lengthOf(); + + const auto aMstride = vA->strideAt(aMaxis); + const auto aNstride = vA->strideAt(aMaxis == 0 ? 
1 : 0); + + auto func = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; ++i) { + + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } + + auto yOffset = i * incy; + + if(betaPersent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } + }; + + samediff::Threads::parallel_tad(func, 0, M); } ////////////////////////////////////////////////////////////////////////////// @@ -144,12 +159,17 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, T3* Z = reinterpret_cast(vZ); T3 alphaZ(alpha), betaZ(beta); + const bool betaPersent = beta; + T3 sum = 0; PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) for(int i = 0; i < length; ++i) sum += X[i * incx] * Y[i * incy]; - - *Z = alphaZ * sum + betaZ * *Z; + + if(betaPersent) + *Z = alphaZ * sum + betaZ * *Z; + else + *Z = alphaZ * sum; } ////////////////////////////////////////////////////////////////////////////// @@ -164,16 +184,15 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of A array is not equal 2 !"); if(B->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); + throw std::runtime_error("MmulHelper::mmulMxM: rank of B array is not equal 2 !"); const auto M = A->sizeAt(0); const auto K = A->sizeAt(1); const auto N = B->sizeAt(1); - const auto bRows = B->sizeAt(0); - + if(C != nullptr && C->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxM: rank of C array is not equal 2 !"); - if(bRows != K) + if(B->sizeAt(0) != K) throw std::runtime_error("MmulHelper::mmulMxM: B array has wrong number of rows !"); if(C != nullptr && C->sizeAt(0) != M) throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of rows !"); @@ -181,61 +200,79 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con throw std::runtime_error("MmulHelper::mmulMxM: C array has wrong number of columns !"); if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); - NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); - - const auto cOrder = C->ordering(); - - if(A->ews() != 1) - pA = pA->dup(cOrder); - if(B->ews() != 1) - pB = pB->dup(cOrder); - if(C->ews() != 1) - pC = pC->dup(cOrder); - - const auto aOrder = pA->ordering(); - const auto bOrder = pB->ordering(); - - const bool transA = aOrder != cOrder; - const bool transB = bOrder != cOrder; - - const CBLAS_ORDER blasOrder = cOrder == 'f' ? CblasColMajor : CblasRowMajor; - const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; - const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; - - const int lda = aOrder == 'f' ? M : K; - const int ldb = bOrder == 'f' ? K : N; - const int ldc = cOrder == 'f' ? 
M : N; - - const auto aType = pA->dataType(); - const auto bType = pB->dataType(); - const auto cType = pC->dataType(); + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); const bool hasGemm = BlasHelper::getInstance()->hasGEMM(aType); - - // we'll use platform-specific gemm here eventually. maybe tomorrow. - // TODO: put proper _gemm here - if (ABC && hasGemm && aType == DataType::FLOAT32) { - BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (float) beta, reinterpret_cast(pC->getBuffer()), ldc); - } - else if (ABC && hasGemm && aType == DataType::DOUBLE) { - BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (double) beta, reinterpret_cast(pC->getBuffer()), ldc); + + const bool typeDouble = hasGemm && ABC && aType == DataType::DOUBLE; + const bool typeFloat = hasGemm && ABC && aType == DataType::FLOAT32; + + if(!typeFloat && !typeDouble) { + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (A, B, C, 0, 1, 0, 1, 0, 1, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } else { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), NUMERIC_TYPES); - //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (cOrder, transA, transB, M, N, K, alpha, pA->getBuffer(), lda, pB->getBuffer(), ldb, beta, pC->getBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); - } - if(pC != C) { - C->assign(pC); - delete pC; + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('f'); + toDelete.push_back(pA); + aMcont = true; + } + if(!bKcont && !bNcont) { + pB = B->dup('f'); + toDelete.push_back(pB); + bKcont = true; + } + if(!cMcont && !cNcont) { + pC = C->dup('f'); + toDelete.push_back(pC); + cMcont = true; + } + + const CBLAS_ORDER blasOrder = cMcont ? CblasColMajor : CblasRowMajor; + + const bool transA = (!aMcont && cMcont) || (aMcont && !cMcont); + const bool transB = (!bKcont && cMcont) || (bKcont && !cMcont); + + const CBLAS_TRANSPOSE transAblas = transA ? CblasTrans : CblasNoTrans; + const CBLAS_TRANSPOSE transBblas = transB ? CblasTrans : CblasNoTrans; + + const int lda = (aMcont && aKcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = (bKcont && bNcont) ? K : !bKcont ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : !cMcont ? 
pC->strideAt(0) : pC->strideAt(1); + + if(typeFloat) { + BlasHelper::getInstance()->sgemm()(blasOrder, transAblas, transBblas, M, N, K, (float) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (float) beta, reinterpret_cast(pC->getBuffer()), ldc); + } + else if(typeDouble) { + BlasHelper::getInstance()->dgemm()(blasOrder, transAblas, transBblas, M, N, K, (double) alpha, reinterpret_cast(pA->getBuffer()), lda, reinterpret_cast(pB->getBuffer()), ldb, (double) beta, reinterpret_cast(pC->getBuffer()), ldc); + } + + if(pC != C) { + C->assign(pC); + delete pC; + } + if(pA != A) + delete pA; + if(pB != B) + delete pB; } - if(pA != A) - delete pA; - if(pB != B) - delete pB; return C; } @@ -243,6 +280,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con //////////////////////////////////////////////////////////////////////////// // MXN x N = M NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + if (X->dataType() != A->dataType()) throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), X->dataType()); @@ -254,56 +292,65 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !"); if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); + throw std::runtime_error("MmulHelper::mmulMxV: X array must be vector !"); - const auto M = A->sizeAt(0); + const auto M = A->sizeAt(0); const auto N = A->sizeAt(1); - + if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) throw std::runtime_error("MmulHelper::mmulMxV: Y array must be vector !"); if(X->lengthOf() != N) throw std::runtime_error("MmulHelper::mmulMxV: X vector has wrong length !"); if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); + throw std::runtime_error("MmulHelper::mmulMxV: Y array has wrong length !"); - if(Y == nullptr) + if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - NDArray *pA(const_cast(A)); - if(A->ews() != 1) - pA = pA->dup(); - - CBLAS_ORDER blasOrder; - int lda; - if (pA->ordering() == 'f') {blasOrder = CblasColMajor; lda = M; } - else {blasOrder = CblasRowMajor; lda = N; } - const int incx = X->stridesOf()[xLenDim]; const int incy = Y->stridesOf()[yLenDim]; - const auto aType = pA->dataType(); + const auto aType = A->dataType(); const auto xType = X->dataType(); const auto yType = Y->dataType(); const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); const bool hasGemv = BlasHelper::getInstance()->hasGEMV(aType); - - // choose appropriate cuda gemm api depending on data types - if(AXY && hasGemv && aType == DataType::DOUBLE) { - BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy); - } - else if(AXY && hasGemv && aType == DataType::FLOAT32) { - BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); + + const bool typeDouble = hasGemv && AXY && aType == DataType::DOUBLE; + const bool typeFloat = hasGemv && AXY 
&& aType == DataType::FLOAT32; + + if(!typeDouble && !typeFloat) { + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), NUMERIC_TYPES); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (A, X, Y, incx, incy, 0, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } else { - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), NUMERIC_TYPES); - //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (pA->ordering(), M, N, alpha, pA->getBuffer(), lda, X->getBuffer(), incx, beta, Y->getBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + + NDArray *pA(const_cast(A)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + + if(!aMcont && !aNcont) { + pA = A->dup('f'); + aMcont = true; + } + const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; + + const int lda = (aMcont && aNcont) ? M : !aMcont ? pA->strideAt(0) : pA->strideAt(1); + + // choose appropriate cuda gemm api depending on data types + if(typeDouble) { + BlasHelper::getInstance()->dgemv()(blasOrder, CblasNoTrans, M, N, alpha, (double*)pA->getBuffer(), lda, (double*)X->getBuffer(), incx, beta, (double*)Y->getBuffer(), incy); + } + else if(typeFloat) { + BlasHelper::getInstance()->sgemv()(blasOrder, CblasNoTrans, M, N, (float)alpha, (float*)pA->getBuffer(), lda, (float*)X->getBuffer(), incx, (float)beta, (float*)Y->getBuffer(), incy); + } + + if(pA != A) + delete pA; } - if(pA != A) - delete pA; - return Y; } @@ -330,22 +377,327 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c if(Y->lengthOf() != length) throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); - if(Z == nullptr) + if(Z == nullptr) Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - + const Nd4jLong incx = X->stridesOf()[xLenDim]; const Nd4jLong incy = Y->stridesOf()[yLenDim]; const auto xType = X->dataType(); const auto yType = Y->dataType(); const auto zType = Z->dataType(); - + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), NUMERIC_TYPES); //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (length, alpha, X->getBuffer(), incx, Y->getBuffer(), incy, beta, Z->getBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); return Z; } +////////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +template +static void batchedGemm(const NDArray* vA, const NDArray* vB, NDArray* vC, + const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + + const T1* A = vA->bufferAsT(); + const T2* B = vB->bufferAsT(); + T3* C = vC->bufferAsT(); + + const T3 alphaZ = alpha; + const T3 betaZ = beta; + + const bool betaPersent = beta; + + const Nd4jLong* aShapeInfo = vA->getShapeInfo(); + const Nd4jLong* bShapeInfo = vB->getShapeInfo(); + const Nd4jLong* cShapeInfo = vC->getShapeInfo(); + + const int aRank = vA->rankOf(); + const int bRank = vB->rankOf(); + const int cRank = vC->rankOf(); + + const Nd4jLong cLen = vC->lengthOf(); + + const int K = vA->sizeAt(aKaxis); + 
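+    // Indexing scheme of the loop below: every output element i of C is computed
+    // independently. First, i is unflattened into the full C coordinates (e.g. {b, m, n})
+    // via shape::index2coords. The batch part of those coordinates is flattened into a
+    // single batch index via shape::coords2index over cBatchDims and then mapped back
+    // into A's and B's (possibly smaller) sets of batch coordinates, which is what
+    // allows shapes like [M,K] x [bS,K,N] = [bS,M,N]. The M/N coordinates are copied
+    // from C, the K coordinate starts at 0, and the inner loop advances the A/B offsets
+    // by their K-axis strides only, so arbitrary (non-contiguous) layouts work without
+    // intermediate copies.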
+ auto func = PRAGMA_THREADS_FOR { + + std::vector aCoords(aRank), bCoords(bRank), cCoords(cRank); + + for (auto i = start; i < stop; ++i) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords.data()); + + // calculate index of current batch + Nd4jLong batchInd; + if(cRank > 2) + batchInd = shape::coords2index(cShapeInfo, cCoords.data(), cRank - 2, cBatchDims); + + // evaluate A coordinates + if(aRank > 2) + shape::index2coords(batchInd, aShapeInfo, aCoords.data(), aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if(bRank > 2) + shape::index2coords(batchInd, bShapeInfo, bCoords.data(), bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords.data()); + auto bOffset = shape::getOffset(bShapeInfo, bCoords.data()); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords.data()); + + if(betaPersent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } + }; + + samediff::Threads::parallel_tad(func, 0, cLen); +} + +////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + const int cRank = C->rankOf(); + + const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); + + std::vector aBatchDims, bBatchDims, cBatchDims; + + if(aRank > 2) + aBatchDims = ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}); + if(bRank > 2) + bBatchDims = ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}); + if(cRank > 2) + cBatchDims = ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}); + + // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (A, B, C, aBatchDims.data(), bBatchDims.data(), cBatchDims.data(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES); + + return C; +} + +/* +////////////////////////////////////////////////////////////////////////// +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + + // multiplication + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); + const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); + std::vector idxRanges(2 * C->rankOf()); + +// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) + for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*C)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*A)(idxRanges); + mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); + } + else if(bRank > aRank) { + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); + } + else { + NDArray aSubArr = (*A)(idxRanges); + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); + } + } + + return C; +} + +////////////////////////////////////////////////////////////////////////////// +// MXK x KxN = MxN +template +static void usualGemm(const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { + + T1* A = reinterpret_cast(const_cast(vA)); + T2* B = reinterpret_cast(const_cast(vB)); + T3* C = reinterpret_cast(vC); + T3 alphaZ(alpha), betaZ(beta); + + const bool flagC = cOrder == 'f'; + const bool flagA = (flagC && transA) || (!flagC && !transA); + const bool flagB = (flagC && transB) || (!flagC && !transB); + + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + // for(uint row = 0; row < M; ++row) { + + // T3* c = flagC ? (C + row) : (C + row * ldc); + + // for(uint col = 0; col < N; ++col) + // c[flagC ? col * ldc : col] = 0; + + // for(uint i = 0; i < K; ++i) { + + // T3* b = flagB ? (B + i * ldb) : (B + i); + // T3* a = flagA ? (A + row * lda + i) : (A + row + i * lda); + + // if(flagC) { + // for(uint col = 0; col < N; ++col) { + // if(betaZ) + // c[col * ldc] += a * b[flagB ? col : col * ldb] + betaZ * c[col * ldc]; + // else + // c[col * ldc] += a * b[flagB ? col : col * ldb]; + // } + // } + // else { + // for(uint col = 0; col < N; ++col) { + // if(betaZ) + // c[col] += a * b[flagB ? col : col * ldb] + betaZ * c[col]; + // else + // c[col] += a * b[flagB ? col : col * ldb]; + // } + // } + // } + // } + + auto func = PRAGMA_THREADS_FOR_2D { ; + for (auto row = start_x; row < stop_x; row += inc_x) { + for (auto col = start_y; col < stop_y; col += inc_y) { + T3 *c = flagC ? (C + row + col * ldc) : (C + row * ldc + col); + T3 val = 0; + + for (uint i = 0; i < K; ++i) { + T3 a = flagA ? *(A + row * lda + i) : *(A + row + i * lda); + T3 b = flagB ? 
*(B + col + i * ldb) : *(B + col * ldb + i); + val += alphaZ * a * b; + } + + if (betaZ) + *c = val + betaZ * *c; + else + *c = val; + } + } + }; + + samediff::Threads::parallel_tad(func, 0, M, 1, 0, N, 1); +} + +////////////////////////////////////////////////////////////////////////////// +// MXN x N = M +template +static void usualGemv(const char aOrder, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vX, const int incx, const double beta, void* vY, const int incy) { + + T1* A = reinterpret_cast(const_cast(vA)); + T2* X = reinterpret_cast(const_cast(vX)); + T3* Y = reinterpret_cast(vY); + T3 alphaZ(alpha), betaZ(beta); + + const bool flagA = aOrder == 'f'; + + auto func = PRAGMA_THREADS_FOR { + for (auto row = start; row < stop; row += increment) { + + T3 *y = Y + row * incy; + T3 val = 0; + + for (int i = 0; i < N; ++i) { + T3 a = flagA ? *(A + row + i * lda) : *(A + row * lda + i); + T3 x = *(X + i * incx); + val += alphaZ * a * x; + } + + if (betaZ) + *y = val + betaZ * *y; + else + *y = val; + } + }; + + samediff::Threads::parallel_tad(func, 0, M); +} +*/ + //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 32394f705..b9cdc00ad 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -22,52 +23,640 @@ #include #include "../MmulHelper.h" #include +#include +#include +#include namespace nd4j { - ////////////////////////////////////////////////////////////////////////////// -// MXK x KxN = MxN -// C array must be in f order +// MXK x KxN = MxN -> actual sequence of axes doesn't matter template -static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +static __global__ void usualCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { - T1* A = reinterpret_cast(const_cast(vA)); - T2* B = reinterpret_cast(const_cast(vB)); - T3* C = reinterpret_cast(vC); + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast< T3*>(vC); + __shared__ int K; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, *coords; __shared__ T3 alphaZ, betaZ; - __shared__ Nd4jLong strideArow, strideAcol, strideBrow, strideBcol; - const int row = blockIdx.y * blockDim.y + threadIdx.y; - const int col = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { - if(row == 0 && col == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); + + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + + betaPresent = beta; + + totalThreads = gridDim.x * blockDim.x; alphaZ = alpha; betaZ = beta; - - if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; } - if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; } } - __syncthreads(); - T3 val = 0; - if (row < M && col < N) - for (int i = 0; i < K; i++) - val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol]; + auto aCoords = coords + threadIdx.x * 6; // 6 = (aRank + bRank + cRank) + auto bCoords = aCoords + 2; + auto cCoords = bCoords + 2; - C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc]; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); + + // evaluate A coordinates + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords); + + if(betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } } //////////////////////////////////////////////////////////////////////// template -__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t 
*stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) { +__host__ static void usualGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { - usualCudaGemm<<>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc); + usualCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); } +//////////////////////////////////////////////////////////////////////// +// MXN x N = M -> actual sequence of {M,N} axes doesn't matter +template +static __global__ void usualCudaGemv(const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, + const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + const T1* A = reinterpret_cast(vA); + const T2* X = reinterpret_cast(vX); + T3* Y = reinterpret_cast< T3*>(vY); + + __shared__ int M, N; + __shared__ bool betaPresent; + __shared__ Nd4jLong cLen, totalThreads, aNstride, aMstride; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + + N = shape::length(xShapeInfo); + M = shape::length(yShapeInfo); + + aMstride = shape::stride(aShapeInfo)[aMaxis]; + aNstride = shape::stride(aShapeInfo)[aMaxis == 0 ? 1 : 0]; + + totalThreads = gridDim.x * blockDim.x; + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < M; i += totalThreads) { + + // evaluate offsets + auto aOffset = i * aMstride; + auto xOffset = 0; + + T3 val = A[aOffset] * X[xOffset]; // first iteration + + for (uint j = 1; j < N; ++j) { // rest iterations + aOffset += aNstride; + xOffset += incx; + val = val + A[aOffset] * X[xOffset]; + } + + auto yOffset = i * incy; + + if(betaPresent) + Y[yOffset] = alphaZ * val + betaZ * Y[yOffset]; + else + Y[yOffset] = alphaZ * val; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void usualGemv(const int blocksPerGrid, const int threadsPerBlock, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vX, const Nd4jLong* xShapeInfo, void* vY, const Nd4jLong* yShapeInfo, const int incx, const int incy, const int aMaxis, const double alpha, const double beta) { + + usualCudaGemv<<>>(vA, aShapeInfo, vX, xShapeInfo, vY, yShapeInfo, incx, incy, aMaxis, alpha, beta); +} + + +////////////////////////////////////////////////////////////////////////////// +template +static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { + + T1* X = reinterpret_cast(const_cast(vX)); + T2* Y = reinterpret_cast(const_cast(vY)); + T3* Z = reinterpret_cast(vZ); + + extern __shared__ unsigned char shmem[]; + auto pairwiseMul = reinterpret_cast(shmem); + + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + if(tid < length) + pairwiseMul[tid] = X[tid * incx] * Y[tid * incy]; + + __syncthreads(); + + if(tid == 0) { + T3 sum 
= 0; + for(Nd4jLong i = 0; i < length; ++i) + sum = sum + pairwiseMul[i]; + + if(beta) + *Z = (T3)alpha * sum + (T3)beta * *Z; + else + *Z = (T3)alpha * sum; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) { + + usualCudaDot<<>>(length, alpha, vX, incx, vY, incy, beta, vZ); +} + +////////////////////////////////////////////////////////////////////////////// +// MXK x KxN = MxN +NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) { + + if(A->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !"); + if(B->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !"); + + const auto M = A->sizeAt(0); + const auto K = A->sizeAt(1); + const auto N = B->sizeAt(1); + + if(C != nullptr && C->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !"); + if(B->sizeAt(0) != K) + throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !"); + if(C != nullptr && C->sizeAt(0) != M) + throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !"); + if(C != nullptr && C->sizeAt(1) != N) + throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !"); + + if(C == nullptr) + C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + + const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); + + const auto aType = A->dataType(); + const auto bType = B->dataType(); + const auto cType = C->dataType(); + + const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); + + const bool typeDouble = ABC && aType == DataType::DOUBLE; + const bool typeFloat = ABC && aType == DataType::FLOAT32; + const bool typeHalf = ABC && aType == DataType::HALF && major >= 6; + const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6; + const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6; + + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + if(!typeDouble && !typeFloat && !typeHalf && !typeIntFloat && !typeHalfFloat) { + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 6 + 128; // 6 = aRank + bRank + cRank + + NDArray::prepareSpecialUse({C}, {A, B}); + // BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, 
(blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), 0, 1, 0, 1, 0, 1, alpha, beta), NUMERIC_TYPES) + NDArray::registerSpecialUse({C}, {A, B}); + } + else { + + std::vector toDelete; + + NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aKcont = K == 1 || A->strideAt(1) == 1; + bool bKcont = K == 1 || B->strideAt(0) == 1; + bool bNcont = N == 1 || B->strideAt(1) == 1; + bool cMcont = M == 1 || C->strideAt(0) == 1; + bool cNcont = N == 1 || C->strideAt(1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('f'); + toDelete.push_back(pA); + aMcont = true; + } + if(!bKcont && !bNcont) { + pB = B->dup('f'); + toDelete.push_back(pB); + bKcont = true; + } + if(!cMcont) { + pC = C->dup('f'); + toDelete.push_back(pC); + cMcont = true; + } + + const bool transA = !aMcont; + const bool transB = !bKcont; + + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(0) : pB->strideAt(1); + const int ldc = (cMcont && cNcont) ? M : pC->strideAt(1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; + + NDArray::prepareSpecialUse({pC}, {pA, pB}); + + // choose appropriate cuda gemm api depending on data types + if(typeDouble) { + status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc); + } + else if(typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); + } + else if(typeHalf) { + float16 alphaH(alpha), betaH(beta); + status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); + } + else if(typeIntFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); + } + else if(typeHalfFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); + } + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + + NDArray::registerSpecialUse({pC}, {pA, pB}); + + if(C != pC) + C->assign(pC); + + for(int i = toDelete.size() - 1; i >= 0; --i) + delete toDelete[i]; + } + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + + return C; +} + +//////////////////////////////////////////////////////////////////////////// +// MXN x N = M +NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { + + int xLenDim, 
yLenDim(0); + + if(A->rankOf() != 2) + throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); + if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); + + const auto M = A->sizeAt(0); + const auto N = A->sizeAt(1); + + if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); + if(X->lengthOf() != N) + throw std::runtime_error("MmulHelper::mmulMxV cuda: X vector has wrong length !"); + if(Y != nullptr && Y->lengthOf() != M) + throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); + + if(Y == nullptr) + Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + + const int incx = X->strideAt(xLenDim); + const int incy = Y->strideAt(yLenDim); + + const auto aType = A->dataType(); + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + + const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); + + const bool typeDouble = AXY && aType == DataType::DOUBLE; + const bool typeFloat = AXY && aType == DataType::FLOAT32; + + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + if(!typeDouble && !typeFloat) { + + const int threadsPerBlock = MAX_NUM_THREADS; + const int blocksPerGrid = (M + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({Y}, {A, X}); + // BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), X->getSpecialBuffer(), X->getSpecialShapeInfo(), Y->getSpecialBuffer(), Y->getSpecialShapeInfo(), incx, incy, 0, alpha, beta), NUMERIC_TYPES) + NDArray::registerSpecialUse({Y}, {A, X}); + + } + else { + + NDArray *pA(const_cast(A)); + + bool aMcont = M == 1 || A->strideAt(0) == 1; + bool aNcont = N == 1 || A->strideAt(1) == 1; + + if(!aMcont && !aNcont) { + pA = A->dup('f'); + aMcont = true; + } + + const bool transA = !aMcont; + + const int lda = (aMcont && aNcont) ? M : transA ? pA->strideAt(0) : pA->strideAt(1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + + NDArray::prepareSpecialUse({Y}, {pA, X}); + + // choose appropriate cuda gemm api depending on data types + if(typeDouble) { + status = cublasDgemv(*handle, transAblas, transA ? N : M, transA ? M : N, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy); + } + else if(typeFloat) { + float alphaF(alpha), betaF(beta); + status = cublasSgemv(*handle, transAblas, transA ? N : M, transA ? 
M : N, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy); + } + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); + + NDArray::registerSpecialUse({Y}, {pA, X}); + + if(pA != A) + delete pA; + } + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); + + return Y; +} + +//////////////////////////////////////////////////////////////////////////// +// (X * Y) = Z[0] +NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { + + int xLenDim(0), yLenDim(0); + + if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) + throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + if(Z != nullptr && !Z->isScalar()) + throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + + const auto length = X->lengthOf(); + + if(Y->lengthOf() != length) + throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); + + if(Z == nullptr) + Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); + + const Nd4jLong incx = X->strideAt(xLenDim); + const Nd4jLong incy = Y->strideAt(yLenDim); + + const auto xType = X->dataType(); + const auto yType = Y->dataType(); + const auto zType = Z->dataType(); + + if(!X->isActualOnDeviceSide()) X->syncToDevice(); + if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); + if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); + + cudaStream_t* stream = X->getContext()->getCudaStream(); + + dim3 threadsPerBlock(512); + dim3 blocksPerGrid(1); + if (length > 512) + threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); + + NDArray::prepareSpecialUse({Z}, {X, Y}); + + //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); + + NDArray::registerSpecialUse({Z}, {X, Y}); + + return Z; +} + +////////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +template +static __global__ void batchedCudaGemm(const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, + const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, + const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, + const double alpha, const double beta) { + + const T1* A = reinterpret_cast(vA); + const T2* B = reinterpret_cast(vB); + T3* C = reinterpret_cast< T3*>(vC); + + __shared__ bool betaPresent; + __shared__ int aRank, bRank, cRank, K; + __shared__ Nd4jLong cLen, 
totalThreads, *coords; + __shared__ T3 alphaZ, betaZ; + + if (threadIdx.x == 0) { + + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + cLen = shape::length(cShapeInfo); + + K = shape::shapeOf(const_cast(aShapeInfo))[aKaxis]; + + totalThreads = gridDim.x * blockDim.x; + aRank = shape::rank(aShapeInfo); + bRank = shape::rank(bShapeInfo); + cRank = shape::rank(cShapeInfo); + + betaPresent = beta; + + alphaZ = alpha; + betaZ = beta; + } + __syncthreads(); + + auto aCoords = coords + threadIdx.x * (aRank + bRank + cRank); + auto bCoords = aCoords + aRank; + auto cCoords = bCoords + bRank; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < cLen; i += totalThreads) { + + // evaluate C coordinates + shape::index2coords(i, cShapeInfo, cCoords); + + // calculate index of current batch + Nd4jLong batchInd; + if(cBatchDims != nullptr) + batchInd = shape::coords2index(cShapeInfo, cCoords, cRank - 2, cBatchDims); + + // evaluate A coordinates + if(aBatchDims != nullptr) + shape::index2coords(batchInd, aShapeInfo, aCoords, aRank - 2, aBatchDims); + aCoords[aMaxis] = cCoords[cMaxis]; + aCoords[aKaxis] = 0; + + // evaluate B coordinates + if(bBatchDims != nullptr) + shape::index2coords(batchInd, bShapeInfo, bCoords, bRank - 2, bBatchDims); + bCoords[bKaxis] = 0; + bCoords[bNaxis] = cCoords[cNaxis]; + + auto aOffset = shape::getOffset(aShapeInfo, aCoords); + auto bOffset = shape::getOffset(bShapeInfo, bCoords); + + T3 val = A[aOffset] * B[bOffset]; // first iteration + + for (uint j = 1; j < K; ++j) { // rest iterations + aOffset += shape::stride(aShapeInfo)[aKaxis]; + bOffset += shape::stride(bShapeInfo)[bKaxis]; + val = val + A[aOffset] * B[bOffset]; + } + + auto cOffset = shape::getOffset(cShapeInfo, cCoords); + + if(betaPresent) + C[cOffset] = alphaZ * val + betaZ * C[cOffset]; + else + C[cOffset] = alphaZ * val; + } +} + +//////////////////////////////////////////////////////////////////////// +template +__host__ static void batchedGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, cudaStream_t *stream, const void* vA, const Nd4jLong* aShapeInfo, const void* vB, const Nd4jLong* bShapeInfo, void* vC, const Nd4jLong* cShapeInfo, const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis, const int cNaxis, const double alpha, const double beta) { + + batchedCudaGemm<<>>(vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta); +} + +/////////////////////////////////////////////////////////////////// +NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not 
+
+///////////////////////////////////////////////////////////////////
+NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
+
+    const int aRank = A->rankOf();
+    const int bRank = B->rankOf();
+
+    // input ranks validation
+    if(aRank > bRank && bRank != 2)
+        throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !");
+    else if(bRank > aRank && aRank != 2)
+        throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !");
+    else if (aRank == bRank ) {
+        for(int i = 0; i < aRank - 2; ++i)
+            if(A->sizeAt(i) != B->sizeAt(i))
+                throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !");
+    }
+
+    if(A->sizeAt(-1) != B->sizeAt(-2))
+        throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !");
+
+    // validation of C array
+    std::vector<Nd4jLong> cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector();
+    cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2);
+    cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1);
+
+    if(C != nullptr ) {
+        if(!C->isSameShape(cExpectedShape))
+            throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !");
+    }
+    else
+        C = new NDArray(outOrder, cExpectedShape, B->dataType());
+
+    const int cRank = C->rankOf();
+
+    const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);
+
+    const int threadsPerBlock = MAX_NUM_THREADS / 8;
+    const int blocksPerGrid = (C->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+    const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * (aRank + bRank + cRank) + 128;
+
+    PointersManager manager(A->getContext(), "MmulHelper::mmulNxN");
+
+    const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr);
+
+    if(aRank > 2)
+        aBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), (aRank - 2) * sizeof(int)));
+    if(bRank > 2)
+        bBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), (bRank - 2) * sizeof(int)));
+    if(cRank > 2)
+        cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int)));
+
+    NDArray::prepareSpecialUse({C}, {A, B});
+    // BUILD_TRIPLE_SELECTOR(A->dataType(), B->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
+    BUILD_SINGLE_SELECTOR_THRICE(A->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->getSpecialBuffer(), A->getSpecialShapeInfo(), B->getSpecialBuffer(), B->getSpecialShapeInfo(), C->getSpecialBuffer(), C->getSpecialShapeInfo(), aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha, beta), NUMERIC_TYPES)
+    NDArray::registerSpecialUse({C}, {A, B});
+
+    manager.synchronize();
+
+    return C;
+}
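[editor note] A hypothetical call-site sketch in the style of the layers_tests (the shapes and values are illustrative, not a test from this patch; they exercise the [bS,M,K] x [K,N] broadcast path, with two leading batch axes on A):

    auto a = NDArrayFactory::create<float>('c', {2, 2, 3, 4});   // two batch axes, M=3, K=4
    auto b = NDArrayFactory::create<float>('c', {4, 5});         // K=4, N=5, broadcast over batches
    a.linspace(1.f);
    b.linspace(0.5f, 0.5f);

    // C == nullptr, so the helper validates shapes and allocates the [2,2,3,5] result itself
    auto c = MmulHelper::mmulNxN(&a, &b, nullptr, 1.0, 0.0, 'c');
    // c->printShapeInfo();
    delete c;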
+
+
+/*
+//////////////////////////////////////////////////////////////////////////////
+// MXN x N = M
+template <typename T1, typename T2, typename T3>
@@ -106,309 +695,331 @@ __host__ static void usualGemv(const dim3 &blocksPerGrid, const dim3 &threadsPer
     usualCudaGemv<T1, T2, T3><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(transA, M, N, alpha, vA, lda, vX, incx, beta, vY, incy);
 }
-
+*/
+/*
 //////////////////////////////////////////////////////////////////////////////
+MXK x KxN = MxN
+C array must be in f order
 template <typename T1, typename T2, typename T3>
-static __global__ void usualCudaDot(const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) {
+static __global__ void usualCudaGemm(const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) {
 
-    T1* X = reinterpret_cast<T1*>(const_cast<void*>(vX));
-    T2* Y = reinterpret_cast<T2*>(const_cast<void*>(vY));
-    T3* Z = reinterpret_cast<T3*>(vZ);
+    T1* A = reinterpret_cast<T1*>(const_cast<void*>(vA));
+    T2* B = reinterpret_cast<T2*>(const_cast<void*>(vB));
+    T3* C = reinterpret_cast<T3*>(vC);
 
-    extern __shared__ char shmem[];
-    auto pairwiseMul = reinterpret_cast<T3*>(shmem);
+    __shared__ T3 alphaZ, betaZ;
+    __shared__ Nd4jLong strideArow, strideAcol, strideBrow, strideBcol;
 
-    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
-    if(tid < length)
-        pairwiseMul[tid] = X[tid * incx] * Y[tid * incy];
+    const int row = blockIdx.y * blockDim.y + threadIdx.y;
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if(row == 0 && col == 0) {
+
+        alphaZ = alpha;
+        betaZ = beta;
+
+        if(transA) { strideArow = lda; strideAcol = 1; } else { strideArow = 1; strideAcol = lda; }
+        if(transB) { strideBrow = ldb; strideBcol = 1; } else { strideBrow = 1; strideBcol = ldb; }
+    }
 
     __syncthreads();
 
-    if(tid == 0) {
-        T3 sum = 0;
-        for(Nd4jLong i = 0; i < length; ++i)
-            sum = sum + pairwiseMul[i];
-        *Z = (T3)alpha * sum + (T3)beta * *Z;
-    }
-}
+    T3 val = 0;
+    if (row < M && col < N)
+        for (int i = 0; i < K; i++)
+            val = val + A[row * strideArow + i * strideAcol] * B[i * strideBrow + col * strideBcol];
 
-////////////////////////////////////////////////////////////////////////
-template <typename T1, typename T2, typename T3>
-__host__ static void usualDot(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ) {
-
-    usualCudaDot<T1, T2, T3><<<blocksPerGrid, threadsPerBlock, length * sizeof(T3), *stream>>>(length, alpha, vX, incx, vY, incy, beta, vZ);
+    C[row + col * ldc] = alphaZ * val + betaZ * C[row + col * ldc];
 }
 
 //////////////////////////////////////////////////////////////////////////////
-// MXK x KxN = MxN
-NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, double alpha, double beta, const char outOrder) {
+template <typename T1, typename T2, typename T3>
+__host__ static void usualGemm(const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc) {
 
-    if(A->rankOf() != 2)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of A array is not equal 2 !");
-    if(B->rankOf() != 2)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of B array is not equal 2 !");
+    usualCudaGemm<T1, T2, T3><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(transA, transB, M, N, K, alpha, vA, lda, vB, ldb, beta, vC, ldc);
+}
+*/
+//////////////////////////////////////////////////////////////////////////
+/*
+NDArray* MmulHelper::mmulNxNold1(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) {
 
-    auto M = A->sizeAt(0);
-    auto K = A->sizeAt(1);
-    auto N = B->sizeAt(1);
+    const int aRank = A->rankOf();
+    const int bRank = B->rankOf();
 
-    if(C != nullptr && C->rankOf() != 2)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: rank of C array is not equal 2 !");
-    if(B->sizeAt(0) != K)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: B array has wrong number of rows !");
-    if(C != nullptr && C->sizeAt(0) != M)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of rows !");
-    if(C != nullptr && C->sizeAt(1) != N)
-        throw std::runtime_error("MmulHelper::mmulMxM cuda: C array has wrong number of columns !");
+    // input ranks validation
+    if(aRank > bRank && bRank != 2)
+        throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !");
+    else if(bRank > aRank && aRank != 2)
+        throw std::runtime_error("MmulHelper::mmulNxN: rank of A array
should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } - if(C == nullptr) - C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else { + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + } + + + // multiplication + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); + const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); + std::vector idxRanges(2 * C->rankOf()); + +// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) + for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*C)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*A)(idxRanges); + mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); + } + else if(bRank > aRank) { + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); + } + else { + NDArray aSubArr = (*A)(idxRanges); + NDArray bSubArr = (*B)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); + } + } + + return C; +} +*/ + +////////////////////////////////////////////////////////////////////////// +// [bS,M,K] x [bS,K,N] = [bS,M,N] +// [bS,M,K] x [K,N] = [bS,M,N] +// [M,K] x [bS,K,N] = [bS,M,N] +// bS could stand for several axes +/* +NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { + + const int aRank = A->rankOf(); + const int bRank = B->rankOf(); + + // input ranks validation + if(aRank > bRank && bRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); + else if(bRank > aRank && aRank != 2) + throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); + else if (aRank == bRank ) { + for(int i = 0; i < aRank - 2; ++i) + if(A->sizeAt(i) != B->sizeAt(i)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + } + + if(A->sizeAt(-1) != B->sizeAt(-2)) + throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); + + // validation of C array + std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); + cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); + cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); + + if(C != nullptr ) { + if(!C->isSameShape(cExpectedShape)) + throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); + } + else + C = new NDArray(outOrder, cExpectedShape, B->dataType()); + + const int cRank = C->rankOf(); + + const auto M = A->sizeAt(-2); + const auto K = A->sizeAt(-1); + const auto N = B->sizeAt(-1); NDArray *pA(const_cast(A)), *pB(const_cast(B)), *pC(const_cast(C)); std::vector toDelete; - if(A->ews() != 1) { - pA = pA->dup('f'); + bool aMcont = M == 1 || A->strideAt(-2) == 1; + bool aKcont = K == 1 || A->strideAt(-1) == 1; + bool bKcont = K == 1 || B->strideAt(-2) == 1; + bool bNcont = N == 1 || B->strideAt(-1) == 1; + bool cMcont = M == 1 || C->strideAt(-2) == 1; + bool cNcont = N == 1 || C->strideAt(-1) == 1; + + if(!aMcont && !aKcont) { + pA = A->dup('c'); toDelete.push_back(pA); + aKcont = true; } - if(B->ews() != 1) { - pB = pB->dup('f'); + if(!bKcont && !bNcont) { + pB = B->dup('c'); toDelete.push_back(pB); + bNcont = true; } - if(C->ews() != 1) { - pC = pC->dup('f'); + std::vector permut(cRank); + if(!cMcont) { + std::iota(permut.begin(), permut.end(), 0); + permut[cRank - 2] = cRank - 1; + permut[cRank - 1] = cRank - 2; // swap two last dimensions [..., M,N] -> [..., N,M] + auto Cpermut = C->permute(permut); + pC = new NDArray('c', Cpermut.getShapeAsVector(), Cpermut.dataType(), A->getContext()); + pC->assign(Cpermut); toDelete.push_back(pC); + cMcont = true; } - if(pC->ordering() != 'f') { - auto temp = pA; - pA = new NDArray(pB ->permute({1,0})); - pB = new NDArray(temp->permute({1,0})); - pC = new NDArray(pC ->permute({1,0})); - toDelete.push_back(pA); - toDelete.push_back(pB); - toDelete.push_back(pC); - M = pA->sizeAt(0); - K = pA->sizeAt(1); - N = pB->sizeAt(1); - } - - const auto aOrder = pA->ordering(); - const auto bOrder = pB->ordering(); - - const bool transA = aOrder != 'f'; - const bool transB = bOrder != 'f'; - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - const cublasOperation_t transBblas = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - - const int lda = aOrder == 'f' ? M : K; - const int ldb = bOrder == 'f' ? K : N; - const int ldc = M; // cOrder == 'f' ? 
M : N; const auto aType = pA->dataType(); const auto bType = pB->dataType(); const auto cType = pC->dataType(); - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); - const bool AB(aType == bType), AC(aType == cType), ABC(AB && AC); - const int deviceId = AffinityManager::currentDeviceId(); - const int major = Environment::getInstance()->capabilities()[deviceId].first(); + bool badTypes = false; + cudaDataType_t cudaType, cudaAType, cudaBType, cudaCType; + + if(ABC && aType == DataType::HALF) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_16F; + } + else if(ABC && aType == DataType::FLOAT32) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_32F; + } + else if(ABC && aType == DataType::DOUBLE) { + cudaType = cudaAType = cudaBType = cudaCType = CUDA_R_64F; + } + else if(AB && cType == DataType::FLOAT32 && aType == DataType::INT8) { + cudaType = cudaCType = CUDA_R_32F; + cudaAType = cudaBType = CUDA_R_8I; + } + else if(AB && cType == DataType::FLOAT32 && aType == DataType::HALF) { + cudaType = cudaCType = CUDA_R_32F; + cudaAType = cudaBType = CUDA_R_16F; + } + else + badTypes = true; + + const int bS = pC->lengthOf() / (M*N); + + const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(cRank, {-2, -1}); NDArray::prepareSpecialUse({pC}, {pA, pB}); - // choose appropriate cuda gemm api depending on data types - if(ABC && aType == DataType::DOUBLE) { - status = cublasDgemm(*handle, transAblas, transBblas, M, N, K, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)pB->getSpecialBuffer(), ldb, &beta, (double*)pC->getSpecialBuffer(), ldc); - } - else if(ABC && aType == DataType::FLOAT32) { - float alphaF(alpha), betaF(beta); - status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc); - } - else if(ABC && aType == DataType::HALF && major >= 6) { - float16 alphaH(alpha), betaH(beta); - status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc); - } - else if(AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_8I, lda, pB->getSpecialBuffer(), CUDA_R_8I, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); - } - else if(AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6) { - float alphaF(alpha), betaF(beta); - status = cublasSgemmEx(*handle, transAblas, transBblas, M, N, K, &alphaF, pA->getSpecialBuffer(), CUDA_R_16F, lda, pB->getSpecialBuffer(), CUDA_R_16F, ldb, &betaF, pC->getSpecialBuffer(), CUDA_R_32F, ldc); - } - else { - dim3 threadsPerBlock(N, M); - dim3 blocksPerGrid(1, 1); - if (M*N > 512){ - threadsPerBlock.x = threadsPerBlock.y = 512; - blocksPerGrid.x = math::nd4j_ceil(static_cast(N) / threadsPerBlock.x); // cols - blocksPerGrid.y = math::nd4j_ceil(static_cast(M) / threadsPerBlock.y); // rows + if(!badTypes) { + + std::vector subArrOffsets(bS); + std::vector subArrShapeInfo(shape::shapeInfoLength(2)); // all sub-arrays have rank = 2 + + std::vector aSubArrs(bS), 
bSubArrs(bS), cSubArrs(bS); + + if(aRank > 2) + shape::calcSubArrShapeAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT(); + + if(bRank > 2) + shape::calcSubArrShapeAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT(); + + shape::calcSubArrShapeAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data()); + for (int i = 0; i < bS; ++i) + cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT(); + + PointersManager manager(A->getContext(), "mmulNxN"); + + const void** aSubArrsCuda = reinterpret_cast(manager.replicatePointer(aSubArrs.data(), aSubArrs.size() * sizeof(void*))); + const void** bSubArrsCuda = reinterpret_cast(manager.replicatePointer(bSubArrs.data(), bSubArrs.size() * sizeof(void*))); + void** cSubArrsCuda = reinterpret_cast< void **>(manager.replicatePointer(cSubArrs.data(), cSubArrs.size() * sizeof(void*))); + + const bool transA = !aMcont; + const bool transB = !bKcont; + + const int lda = (aMcont && aKcont) ? M : transA ? pA->strideAt(-2) : pA->strideAt(-1); + const int ldb = (bKcont && bNcont) ? K : transB ? pB->strideAt(-2) : pB->strideAt(-1); + const int ldc = (cMcont && cNcont) ? M : C != pC ? pC->strideAt(-2) : pC->strideAt(-1); + + const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t transBblas = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + union Coeff {__half _h; float _f; double _d; }; + Coeff uAlpha, uBeta; + + if(cudaType == CUDA_R_16F) { + uAlpha._h = alpha; + uBeta._h = beta; + } + else if(cudaType == CUDA_R_32F) { + uAlpha._f = alpha; + uBeta._f = beta; + } + else if(cudaType == CUDA_R_64F) { + uAlpha._d = alpha; + uBeta._d = beta; } - //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES) + auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); + auto stream = A->getContext()->getCudaStream(); + + auto status = cublasSetStream_v2(*handle, *stream); + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + + status = cublasGemmBatchedEx(*handle, transAblas, transBblas, M, N, K, &uAlpha, aSubArrsCuda, cudaAType, lda, bSubArrsCuda, cudaBType, ldb, &uBeta, cSubArrsCuda, cudaCType, ldc, bS, cudaType, CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", status); + + auto cudaResult = cudaStreamSynchronize(*stream); + if (cudaResult != 0) + throw cuda_exception::build("MmulHelper::mmulNxN cuda failed !", cudaResult); } + else { - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); + std::vector idxRanges(2 * pC->rankOf()); - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", cudaResult); + for(Nd4jLong i = 0; i < bS; ++i) { + + ShapeUtils::evalIdxRangesForSubArr(i, pC->getShapeInfo(), dimsToExclude, idxRanges.data()); + NDArray cSubArr = (*pC)(idxRanges); + + if(aRank > bRank) { + NDArray aSubArr = (*pA)(idxRanges); + mmulMxM(&aSubArr, pB, &cSubArr, 1., 0., pC->ordering()); + } + else if(bRank > aRank) { + NDArray bSubArr = (*pB)(idxRanges); + mmulMxM(pA, &bSubArr, &cSubArr, 1., 0, pC->ordering()); + } + else { + NDArray aSubArr = (*pA)(idxRanges); + NDArray bSubArr = (*pB)(idxRanges); + mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., pC->ordering()); + } + } + } NDArray::registerSpecialUse({pC}, {pA, pB}); - if(C->ews() != 1) - C->assign(pC); + if(C != pC) + C->assign(pC->permute(permut)); for(int i = toDelete.size() - 1; i >= 0; --i) delete toDelete[i]; return C; } - -//////////////////////////////////////////////////////////////////////////// -// MXN x N = M -NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* Y, const double alpha, const double beta, const char outOrder) { - - int xLenDim, yLenDim(0); - - if(A->rankOf() != 2) - throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !"); - if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: X array must be vector !"); - - const auto M = A->sizeAt(0); - const auto N = A->sizeAt(1); - - if(Y != nullptr && !shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array must be vector !"); - if(X->lengthOf() != N) - throw std::runtime_error("MmulHelper::mmulMxV 
cuda: X vector has wrong length !"); - if(Y != nullptr && Y->lengthOf() != M) - throw std::runtime_error("MmulHelper::mmulMxV cuda: Y array has wrong length !"); - - if(Y == nullptr) - Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); - - NDArray *pA(const_cast(A)); - - if(A->ews() != 1) - pA = pA->dup('f'); - - const bool transA = pA->ordering() == 'c'; - - const cublasOperation_t transAblas = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - - int lda, lta; - if(transA) { lda = N; lta = M; } - else { lda = M; lta = N; } - - const int incx = X->stridesOf()[xLenDim]; - const int incy = Y->stridesOf()[yLenDim]; - - const auto aType = pA->dataType(); - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - - auto handle = reinterpret_cast(A->getContext()->getCublasHandle()); - auto stream = A->getContext()->getCudaStream(); - - auto status = cublasSetStream_v2(*handle, *stream); - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - const bool AX(aType == xType), AY(aType == yType), AXY(AX && AY); - - NDArray::prepareSpecialUse({Y}, {pA, X}); - - // choose appropriate cuda gemm api depending on data types - if(AXY && aType == DataType::DOUBLE) { - status = cublasDgemv(*handle, transAblas, lda, lta, &alpha, (double*)pA->getSpecialBuffer(), lda, (double*)X->getSpecialBuffer(), incx, &beta, (double*)Y->getSpecialBuffer(), incy); - } - else if(AXY && aType == DataType::FLOAT32) { - float alphaF(alpha), betaF(beta); - status = cublasSgemv(*handle, transAblas, lda, lta, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)X->getSpecialBuffer(), incx, &betaF, (float*)Y->getSpecialBuffer(), incy); - } - else { - dim3 threadsPerBlock(M); - dim3 blocksPerGrid(1); - if (M > 512){ - threadsPerBlock.x = 512; - blocksPerGrid.x = math::nd4j_ceil(static_cast(M) / threadsPerBlock.x); // rows - } - //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES) - } - - if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Y}, {pA, X}); - - if(pA != A) - delete pA; - - return Y; -} - -//////////////////////////////////////////////////////////////////////////// -// (X * Y) = Z[0] -NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, const double alpha, const double beta) { - - int xLenDim(0), yLenDim(0); - - if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); - if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); - if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); - - const auto length = X->lengthOf(); - - if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: 
lengths of input vectors are different !"); - - if(Z == nullptr) - Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); - - const Nd4jLong incx = X->stridesOf()[xLenDim]; - const Nd4jLong incy = Y->stridesOf()[yLenDim]; - - const auto xType = X->dataType(); - const auto yType = Y->dataType(); - const auto zType = Z->dataType(); - - if(!X->isActualOnDeviceSide()) X->syncToDevice(); - if(!Y->isActualOnDeviceSide()) Y->syncToDevice(); - if(!Z->isActualOnDeviceSide()) Z->syncToDevice(); - - cudaStream_t* stream = X->getContext()->getCudaStream(); - - dim3 threadsPerBlock(512); - dim3 blocksPerGrid(1); - if (length > 512) - threadsPerBlock.x = math::nd4j_ceil(static_cast(length) / 512); - - NDArray::prepareSpecialUse({Z}, {X, Y}); - - //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); - BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES) - - auto cudaResult = cudaStreamSynchronize(*stream); - if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); - - NDArray::registerSpecialUse({Z}, {X, Y}); - - return Z; -} +*/ //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES); diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index b50104bee..ab97ad137 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -184,69 +184,6 @@ NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray #endif -////////////////////////////////////////////////////////////////////////// -NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { - - const int aRank = A->rankOf(); - const int bRank = B->rankOf(); - - // input ranks validation - if(aRank > bRank && bRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); - else if(bRank > aRank && aRank != 2) - throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); - else if (aRank == bRank ) { - for(int i = 0; i < aRank - 2; ++i) - if(A->sizeAt(i) != B->sizeAt(i)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - } - - if(A->sizeAt(-1) != B->sizeAt(-2)) - throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); - - // validation of C array - std::vector cExpectedShape = aRank > bRank ? 
A->getShapeAsVector() : B->getShapeAsVector(); - cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); - cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); - - if(C != nullptr ) { - if(!C->isSameShape(cExpectedShape)) - throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); - } - else { - C = new NDArray(outOrder, cExpectedShape, B->dataType()); - } - - - // multiplication - const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); - std::vector idxRanges(2 * C->rankOf()); - -// #pragma omp parallel for schedule(guided) firstprivate(idxRanges) - for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { - - ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); - NDArray cSubArr = (*C)(idxRanges); - - if(aRank > bRank) { - NDArray aSubArr = (*A)(idxRanges); - mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); - } - else if(bRank > aRank) { - NDArray bSubArr = (*B)(idxRanges); - mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); - } - else { - NDArray aSubArr = (*A)(idxRanges); - NDArray bSubArr = (*B)(idxRanges); - mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); - } - } - - return C; -} - ////////////////////////////////////////////////////////////////////////// nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, nd4j::NDArray* C , const double alpha, const double beta, const char outOrder) { diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index b8cbf1b37..cdcddc92d 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -901,6 +901,10 @@ namespace shape { */ ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords); ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims); @@ -910,6 +914,10 @@ namespace shape { */ ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords); ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! 
+ */
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords, const int dimsSize, const int* tadDims);
 
     /**
     * increment n-dimensional array by one iteration by changing coord appropriately
@@ -1762,6 +1770,19 @@ INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape,
     return index;
 }
 
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords, const int dimsSize, const int* tadDims) {
+
+    Nd4jLong index, shift = 1;
+
+    index = coords[tadDims[dimsSize - 1]];
+    for(uint i = dimsSize - 1; i >= 1; --i) {
+        shift *= shapeInfo[1 + tadDims[i]];
+        index += shift * coords[tadDims[i - 1]];
+    }
+
+    return index;
+}
+
 template <typename T>
 INLINEDEF _CUDA_HD void fill(T* buffer, T value, Nd4jLong length) {
 
@@ -3957,9 +3978,13 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
         oldStart = oldStop++;
     }
 
-    newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo);                 // order
-    newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo);     // ews
-    newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);                  // type
+    // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
+    for (int i = newStart; i < newRank; ++i)
+        newStrides[i] = 1;
+
+    newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo);                 // order
+    newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo);     // ews
+    newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);                  // type
 
     return true;
 }
@@ -4705,6 +4730,16 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
     coords[0] = index;      // last iteration
 }
 
+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {
+
+    for(uint i = dimsSize - 1; i > 0; --i) {
+        coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]];
+        index /= shapeInfo[1 + tadDims[i]];
+    }
+    coords[tadDims[0]] = index;      // last iteration
+}
+
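[editor note] A self-contained round-trip check of the two tadDims helpers above (local re-implementations for illustration, mirroring the shapeInfo layout where shapeInfo[0] is the rank and shapeInfo[1 + d] the extent of dimension d; note coords2index must read the shape as shapeInfo[1 + tadDims[i]] and the coordinates as coords[tadDims[i - 1]], matching the companion index2coords):

    #include <cstdio>

    typedef long long Nd4jLong;

    static void index2coordsTad(Nd4jLong index, const Nd4jLong* shapeInfo, Nd4jLong* coords, int dimsSize, const int* tadDims) {
        for (int i = dimsSize - 1; i > 0; --i) {
            coords[tadDims[i]] = index % shapeInfo[1 + tadDims[i]];
            index /= shapeInfo[1 + tadDims[i]];
        }
        coords[tadDims[0]] = index;                  // last iteration
    }

    static Nd4jLong coords2indexTad(const Nd4jLong* shapeInfo, const Nd4jLong* coords, int dimsSize, const int* tadDims) {
        Nd4jLong index = coords[tadDims[dimsSize - 1]], shift = 1;
        for (int i = dimsSize - 1; i >= 1; --i) {
            shift *= shapeInfo[1 + tadDims[i]];
            index += shift * coords[tadDims[i - 1]];
        }
        return index;
    }

    int main() {
        // rank-4 array of shape {2,3,4,5}; only shapeInfo[0..rank] matters to these helpers
        const Nd4jLong shapeInfo[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99};
        const int tadDims[] = {0, 1};                // two leading batch axes, sorted ascending
        Nd4jLong coords[4] = {0, 0, 0, 0};

        for (Nd4jLong i = 0; i < 2 * 3; ++i) {       // every batch index must survive the round trip
            index2coordsTad(i, shapeInfo, coords, 2, tadDims);
            std::printf("%lld -> (%lld,%lld) -> %lld\n", i, coords[0], coords[1], coords2indexTad(shapeInfo, coords, 2, tadDims));
        }
        return 0;
    }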
 //////////////////////////////////////////////////////////////////////
 INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
 
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu
index 273749bfd..b26702b25 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
index 8a9986e23..1a5a255ee 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp
index 1dc2c8e48..085127e74 100644
--- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp
+++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -1798,6 +1799,7 @@ TEST_F(HelpersTests1, tensordot_test_6) {
 
     // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
     MmulHelper::tensorDot(&a, &b, &cR, {{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}});
+    // c.printBuffer();
 
     ASSERT_TRUE(c.isSameShape(expected));
     ASSERT_TRUE(c.equalsTo(expected));
diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp
index 747ecc183..0f3cab509 100644
--- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp
@@ -1891,7 +1891,7 @@ TEST_F(NDArrayTest, TestMMulMultiDim) {
 
     ASSERT_TRUE(result->isSameShape(&expected));
     //result->printShapeInfo("result shape");
-    //result->printBuffer("result buffer");
+    // result->printBuffer("result buffer");
     ASSERT_TRUE(result->equalsTo(&expected));
     delete result;
 }
diff --git a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp
index 998b8164b..68c68aafb 100644
--- a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp
@@ -61,8 +61,10 @@ public:
 TEST_F(PerformanceTests, test_maxpooling2d_1) {
     std::vector<Nd4jLong> valuesX;
-    auto x = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
-    auto z = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+    // auto x = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+    // auto z = NDArrayFactory::create<float>('c', {32, 3, 224, 224});
+    auto x = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
+    auto z = NDArrayFactory::create<float>('c', {8, 3, 64, 64});
     x.linspace(1.0f);
 
     Nd4jLong k = 5;
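[editor note] The commented-out depthwise_conv2d playground test below drives SAME padding (paddingMode = 1) on a 56x56 input with a 3x3 kernel and unit strides, which is why it can assume oH = oW = 56. A minimal sketch of that output-size arithmetic (standalone illustration, not the ConvolutionUtils API):

    #include <cstdio>

    int main() {
        const int iH = 56, kH = 3, sH = 1, dH = 1;   // matches the test below
        const int kHeff  = kH + (kH - 1) * (dH - 1); // effective kernel extent under dilation
        const int oSame  = (iH + sH - 1) / sH;       // SAME: ceil(iH / sH), padding chosen to fit
        const int oValid = (iH - kHeff + sH) / sH;   // VALID: floor((iH - kHeff) / sH) + 1

        std::printf("SAME: %d, VALID: %d\n", oSame, oValid);   // SAME: 56, VALID: 54
        return 0;
    }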
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
index dfb685e22..1b99a99e6 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -274,4 +275,28 @@ TEST_F(PlaygroundTests, test_relubp_1) {
     nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw);
 }
 
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, my) {
+
+    int bS=1, iH=56,iW=56,  iC=144,mC=1,  kH=3,kW=3,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
+    int       oC=iC*mC;
+    int       oH=56,oW=56;
+
+    int paddingMode = 1;             // 1-SAME, 0-VALID;
+    int dataFormat  = 1;             // 1-NHWC, 0-NCHW
+
+    auto input   = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
+    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
+
+    input = 2.;
+    weights.linspace(0.1, 0.1);
+
+    nd4j::ops::depthwise_conv2d op;
+    auto results = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
+
+    delete results;
+}
+
 */
 
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java
index de9edfc2e..73bb11bea 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOpsHolder.java
@@ -106,7 +106,7 @@ public class NativeOpsHolder {
             boolean logInit = Boolean.parseBoolean(logInitProperty);
 
             if(logInit) {
-                log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
+                log.info("Number of threads used for linear algebra: {}", deviceNativeOps.ompGetMaxThreads());
             }
         } catch (Exception | Error e) {
             throw new RuntimeException(
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
index e8b5e15c9..e03937d0f 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
@@ -4600,6 +4600,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
          */
         public native @Cast("Nd4jLong") long sizeAt(int dim);
 
+        /**
+         * returns stride of "dim" dimension
+         */
+        public native @Cast("Nd4jLong") long strideAt(int dim);
+
         /**
          * returns order of array
          */
@@ -8019,6 +8024,12 @@ public static final int PREALLOC_SIZE = 33554432;
     @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
     @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
     @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
+    /**
+     * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
+ */ + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @@ -8032,6 +8043,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -9088,6 +9105,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e2e9b0c2f..84f8b2c12 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -4600,6 +4600,11 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); */ public native @Cast("Nd4jLong") long sizeAt(int dim); + /** + * returns stride of "dim" dimension + */ + public native @Cast("Nd4jLong") long strideAt(int dim); + /** * returns order of array */ @@ -8019,6 +8024,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, 
@Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @@ -8032,6 +8043,12 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + /** + * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! + */ + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -9088,6 +9105,8 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// +
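[editor note] As a closing illustration of the sizeAt/strideAt pair surfaced through these bindings: both follow the usual negative-axis convention. A short libnd4j test-style C++ fragment (illustrative only; the concrete numbers assume a dense c-order array):

    auto x = NDArrayFactory::create<float>('c', {2, 3, 4});      // c-order strides are {12, 4, 1}

    printf("%lld %lld\n", x.sizeAt(-1), x.strideAt(-1));         // 4 1   (negative dims count from the end)
    printf("%lld %lld\n", x.sizeAt(-2), x.strideAt(-2));         // 3 4
    printf("%lld %lld\n", x.sizeAt(0),  x.strideAt(0));          // 2 12  (dim >= 0 path)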