diff --git a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
index 684d98d6d..8ab5fa32f 100644
--- a/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
+++ b/libnd4j/include/ops/declarable/generic/transforms/layer_norm.cpp
@@ -38,9 +38,13 @@ namespace ops {
         const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;       // INT_ARG(9): 0-NCHW, 1-NHWC
         const int dimC = isNCHW ? 1 : input->rankOf() - 1;
 
+        REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
+
         NDArray* bias = nullptr;
-        if (block.width() > 2)
+        if (block.width() > 2) {
             bias = INPUT_VARIABLE(2);
+            REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
+        }
 
         std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
@@ -80,13 +84,16 @@ namespace ops {
         const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;       // INT_ARG(9): 0-NCHW, 1-NHWC
         const int dimC = isNCHW ? 1 : input->rankOf() - 1;
 
+        REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
+
         std::vector<int> axis = *block.getIArguments();
 
         std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
 
         if(bias != nullptr) {
+            REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
             // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
-            eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+            eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
         }
 
         NDArray standardized(input->shapeInfo(), false, block.launchContext());
@@ -99,7 +106,7 @@ namespace ops {
         standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
 
         standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr);
-        standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+        standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
 
         nd4j::ops::standardize_bp standardizeBp;
         // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
index c07c9adf8..da4f5cc86 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu
@@ -100,7 +100,7 @@ static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int thr
 BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES);
 
 //////////////////////////////////////////////////////////////////////////
-static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) {
+static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) {
 
     // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f'
     // we make this function to have deal with 2 valid cases only:
@@ -113,59 +113,59 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     // U [m, m] or [m, n] if fullUV = false and m > n
     // VT [n, n] or [m, n] if fullUV = false and m < n
 
-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdQR: rank of A array is not equal 2 !");
 
-    auto m = A.sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto m = A->sizeAt(0);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;
-    const char orderA = A.ordering();
+    const char orderA = A->ordering();
 
     if(m < n)
         throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");
 
-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdQR: wrong shape of S array !");
 
     if(calcUV) {
 
-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
             throw std::runtime_error("svdQR: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
            throw std::runtime_error("svdQR: wrong shape of U array !");
 
-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT))
            throw std::runtime_error("svdQR: wrong shape of VT array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT))
+        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT))
            throw std::runtime_error("svdQR: wrong shape of VT array !");
     }
 
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pVT = &VT;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pVT = VT;
 
     std::vector<NDArray*> toDelete;
 
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
 
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }
 
     if(calcUV) {
 
         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }
 
         if(pVT->ews() != 1 || pVT->ordering() == 'c') {
-            pVT = VT.dup('f');
+            pVT = VT->dup('f');
             toDelete.push_back(pVT);
         }
     }
@@ -183,9 +183,9 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
+    if(A->dataType() == DataType::DOUBLE)
         status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
-    else if(A.dataType() == DataType::FLOAT32)
+    else if(A->dataType() == DataType::FLOAT32)
         status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
@@ -195,7 +195,7 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     // allocate memory for dWork
     void* dWork = nullptr;
-    cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdQR: cuda failed !", status2);
@@ -226,11 +226,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     NDArray::prepareSpecialUse({pS, pU, pVT}, {pA});
 
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
     }
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
@@ -242,11 +242,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 
     NDArray::registerSpecialUse({pS, pU, pVT}, {pA});
 
-    S.assign(pS);
+    S->assign(pS);
 
     if(calcUV) {
-        U.assign(pU);
-        VT.assign(pVT);
+        U->assign(pU);
+        VT->assign(pVT);
     }
 
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -266,62 +266,62 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 }
 
 //////////////////////////////////////////////////////////////////////////
-static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {
 
     // A [m, n]
     // S [n]
     // U [m, m] or [m, n] if fullUV = false and m > n
     // V [n, n] or [n, m] if fullUV = false and m < n
 
-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdJcb: rank of A array is not equal 2 !");
 
-    auto m = A.sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto m = A->sizeAt(0);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;
 
-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdJcb: wrong shape of S array !");
 
     if(calcUV) {
 
-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
            throw std::runtime_error("svdJcb: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
            throw std::runtime_error("svdJcb: wrong shape of U array !");
 
-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V))
            throw std::runtime_error("svdJcb: wrong shape of V array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V))
+        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V))
            throw std::runtime_error("svdJcb: wrong shape of V array !");
     }
 
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pV = &V;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pV = V;
 
     std::vector<NDArray*> toDelete;
 
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
 
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }
 
     if(calcUV) {
 
         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }
 
         if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
             toDelete.push_back(pV);
         }
     }
@@ -362,10 +362,10 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
-    else if(A.dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
+    if(A->dataType() == DataType::DOUBLE)
+        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
+    else if(A->dataType() == DataType::FLOAT32)
+        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -374,7 +374,7 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 
     // allocate memory dWork
     void* dWork = nullptr;
-    auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdJcb: cuda failed !", status2);
@@ -383,11 +383,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
 
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
     }
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -399,11 +399,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 
     NDArray::registerSpecialUse({pS, pU, pV}, {pA});
 
-    S.assign(pS);
+    S->assign(pS);
 
     if(calcUV) {
-        U.assign(pU);
-        V.assign(pV);
+        U->assign(pU);
+        V->assign(pV);
     }
 
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -422,67 +422,67 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 }
 
 //////////////////////////////////////////////////////////////////////////
-static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {
 
     // A [..., m, n]
     // S [..., n]
     // U [..., m, m] or [..., m, n] if fullUV = false and m > n
     // V [..., n, n] or [..., n, m] if fullUV = false and m < n
 
-    auto m = A.sizeAt(-2);
-    auto n = A.sizeAt(-1);
+    auto m = A->sizeAt(-2);
+    auto n = A->sizeAt(-1);
     const int minDim = m < n ? m : n;
-    const Nd4jLong bS = A.lengthOf() / (m * n);
+    const Nd4jLong bS = A->lengthOf() / (m * n);
 
     if(m > 32 || n > 32)
         throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !");
 
-    if(minDim != S.sizeAt(-1))
+    if(minDim != S->sizeAt(-1))
         throw std::runtime_error("svdBatched: wrong shape of S array !");
 
     if(calcUV) {
 
-        if(U.sizeAt(-2) != m)
+        if(U->sizeAt(-2) != m)
            throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.sizeAt(-1) != (fullUV ? m : minDim))
+        if(U->sizeAt(-1) != (fullUV ? m : minDim))
            throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS)
+        if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS)
            throw std::runtime_error("svdBatched: wrong shape of U array !");
 
-        if(V.sizeAt(-2) != n)
+        if(V->sizeAt(-2) != n)
            throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.sizeAt(-1) != (fullUV ? n : minDim))
+        if(V->sizeAt(-1) != (fullUV ? n : minDim))
           throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS)
+        if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS)
           throw std::runtime_error("svdBatched: wrong shape of V array !");
     }
 
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pV = &V;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pV = V;
 
     std::vector<NDArray*> toDelete;
 
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
 
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
        toDelete.push_back(pS);
    }
 
    if(calcUV) {
 
        if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
            toDelete.push_back(pU);
        }
 
        if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
            toDelete.push_back(pV);
        }
    }
@@ -532,10 +532,10 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
 
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
-    else if(A.dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
+    if(A->dataType() == DataType::DOUBLE)
+        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
+    else if(A->dataType() == DataType::FLOAT32)
+        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -544,7 +544,7 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
 
     // allocate memory dWork
     void* dWork = nullptr;
-    status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdBatched: cuda failed !", status2);
     status2 = cudaDeviceSynchronize();
@@ -556,11 +556,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
 
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
 
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -572,11 +572,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
 
     NDArray::registerSpecialUse({pS, pU, pV}, {pA});
 
-    S.assign(pS);
+    S->assign(pS);
 
     if(calcUV) {
-        U.assign(pU);
-        V.assign(pV);
+        U->assign(pU);
+        V->assign(pV);
     }
 
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -603,8 +603,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector
 
     if(x->rankOf() == 2) {
-        // svdQR(context, *x, *S, *U, VT, fullUV, calcUV);
-        svdJcb(context, *x, *S, *U, *V, fullUV, calcUV);
+        // svdQR(context, x, S, U, VT, fullUV, calcUV);
+        svdJcb(context, x, S, U, V, fullUV, calcUV);
     }
     else {
@@ -621,7 +621,7 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector
 
         for (Nd4jLong i = 0; i < tadsX->size(); ++i)
-            svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV);
+            svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV);
 
         delete tadsX;
         delete tadsS;
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index 21a0381e9..65cb470a7 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -245,8 +245,8 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) {
 
 TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
 
     nd4j::ops::layer_norm op;
     auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
@@ -256,8 +256,8 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
     auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});
 
     nd4j::ops::layer_norm_bp op;
@@ -266,6 +266,26 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     delete result;
 }
 
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) {
+
+    NDArray x('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, nd4j::DataType::FLOAT32);
+    NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, nd4j::DataType::FLOAT32);
+    NDArray gradO('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+
+    NDArray gradI('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gradG('c', {4}, nd4j::DataType::FLOAT32);
+    NDArray gradB('c', {4}, nd4j::DataType::FLOAT32);
+
+    x.linspace(-20, 0.5);
+    gradO.linspace(-4, 0.05);
+
+    nd4j::ops::layer_norm_bp op;
+    auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true});
+    ASSERT_EQ(Status::OK(), status);
+}
+
 TEST_F(DeclarableOpsTests15, test_hashCode_1) {
     auto x = NDArrayFactory::create('c', {10});
     auto y = NDArrayFactory::create('c', {10});
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
index 992b21c0f..cff84c69b 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp
@@ -87,4 +87,4 @@ TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
     ASSERT_EQ(Status::OK(), status);
 
     ASSERT_EQ(e, z);
-}
\ No newline at end of file
+}