Shyrma fix2 (#186)

* - further work on layer_norm

Signed-off-by: Yurii <yurii@skymind.io>

* - further work on layer_norm 2

Signed-off-by: Yurii <yurii@skymind.io>

* - correct helpers for svd cuda

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-27 19:57:59 +03:00 committed by raver119
parent 650539528c
commit 2144941313
4 changed files with 129 additions and 102 deletions


@@ -38,9 +38,13 @@ namespace ops {
     const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;       // INT_ARG(9): 0-NCHW, 1-NHWC
     const int dimC = isNCHW ? 1 : input->rankOf() - 1;
+    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
     NDArray* bias = nullptr;
-    if (block.width() > 2)
+    if (block.width() > 2) {
         bias = INPUT_VARIABLE(2);
+        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
+    }
     std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
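
With these checks, gain (and bias, when a third input is supplied) must be rank-1 vectors whose length equals input->sizeAt(dimC). A minimal usage sketch mirroring the updated tests further down (values illustrative only):

    auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});   // rank-1, length == sizeAt(dimC)
    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
    nd4j::ops::layer_norm op;
    auto result = op.execute({&x, &g, &b}, {}, {0}, {false});                 // B arg false -> NHWC, so dimC = rank - 1
    delete result;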
@@ -80,13 +84,16 @@ namespace ops {
     const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;       // INT_ARG(9): 0-NCHW, 1-NHWC
     const int dimC = isNCHW ? 1 : input->rankOf() - 1;
+    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
     std::vector<int> axis = *block.getIArguments();
     std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
     if(bias != nullptr) {
+        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
         // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
-        eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+        eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
     }
     NDArray standardized(input->shapeInfo(), false, block.launchContext());
@@ -99,7 +106,7 @@ namespace ops {
     standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
     standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr);
-    standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+    standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
     nd4j::ops::standardize_bp standardizeBp;
     // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
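
Dropping the keepDims flag makes dLdb and dLdg come out rank-1 as well, matching the rank-1 bias and gain enforced above. A hedged sketch of the reduction using the names from the snippet (for a rank-4 NCHW input, dimC = 1; illustrative, not the full op implementation):

    // sum over every axis except the channel axis -> gradients of shape {input->sizeAt(dimC)}
    std::vector<int> nonChannelDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC});   // e.g. {0, 2, 3}
    eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, nonChannelDims);           // dLdb[c] = sum over b,h,w of eps
    standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, nonChannelDims);   // dLdg[c] = sum over b,h,w of standardize(x) * eps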


@@ -100,7 +100,7 @@ static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int thr
 BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES);
 //////////////////////////////////////////////////////////////////////////
-static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) {
+static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) {
 // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f'
 // we make this function to have deal with 2 valid cases only:
@@ -113,59 +113,59 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 // U [m, m] or [m, n] if fullUV = false and m > n
 // VT [n, n] or [m, n] if fullUV = false and m < n
-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdQR: rank of A array is not equal 2 !");
-    auto m = A.sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto m = A->sizeAt(0);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;
-    const char orderA = A.ordering();
+    const char orderA = A->ordering();
     if(m < n)
         throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");
-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdQR: wrong shape of S array !");
     if(calcUV) {
-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
            throw std::runtime_error("svdQR: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
           throw std::runtime_error("svdQR: wrong shape of U array !");
-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT))
           throw std::runtime_error("svdQR: wrong shape of VT array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT))
+        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT))
           throw std::runtime_error("svdQR: wrong shape of VT array !");
     }
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pVT = &VT;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pVT = VT;
     std::vector<NDArray*> toDelete;
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }
     if(calcUV) {
         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }
         if(pVT->ews() != 1 || pVT->ordering() == 'c') {
-            pVT = VT.dup('f');
+            pVT = VT->dup('f');
             toDelete.push_back(pVT);
         }
     }
@@ -183,9 +183,9 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
+    if(A->dataType() == DataType::DOUBLE)
         status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
-    else if(A.dataType() == DataType::FLOAT32)
+    else if(A->dataType() == DataType::FLOAT32)
         status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
@@ -195,7 +195,7 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     // allocate memory for dWork
     void* dWork = nullptr;
-    cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdQR: cuda failed !", status2);
@@ -226,11 +226,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     NDArray::prepareSpecialUse({pS, pU, pVT}, {pA});
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
     }
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
@@ -242,11 +242,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     NDArray::registerSpecialUse({pS, pU, pVT}, {pA});
-    S.assign(pS);
+    S->assign(pS);
     if(calcUV) {
-        U.assign(pU);
-        VT.assign(pVT);
+        U->assign(pU);
+        VT->assign(pVT);
     }
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -266,62 +266,62 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 }
 //////////////////////////////////////////////////////////////////////////
-static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {
 // A [m, n]
 // S [n]
 // U [m, m] or [m, n] if fullUV = false and m > n
 // V [n, n] or [n, m] if fullUV = false and m < n
-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdJcb: rank of A array is not equal 2 !");
-    auto m = A.sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto m = A->sizeAt(0);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;
-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdJcb: wrong shape of S array !");
     if(calcUV) {
-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
           throw std::runtime_error("svdJcb: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
          throw std::runtime_error("svdJcb: wrong shape of U array !");
-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V))
          throw std::runtime_error("svdJcb: wrong shape of V array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V))
+        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V))
         throw std::runtime_error("svdJcb: wrong shape of V array !");
     }
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pV = &V;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pV = V;
     std::vector<NDArray*> toDelete;
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }
     if(calcUV) {
         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }
         if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
            toDelete.push_back(pV);
        }
    }
@@ -362,10 +362,10 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
-    else if(A.dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
+    if(A->dataType() == DataType::DOUBLE)
+        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
+    else if(A->dataType() == DataType::FLOAT32)
+        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -374,7 +374,7 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
     // allocate memory dWork
     void* dWork = nullptr;
-    auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdJcb: cuda failed !", status2);
@@ -383,11 +383,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
     }
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -399,11 +399,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
     NDArray::registerSpecialUse({pS, pU, pV}, {pA});
-    S.assign(pS);
+    S->assign(pS);
     if(calcUV) {
-        U.assign(pU);
-        V.assign(pV);
+        U->assign(pU);
+        V->assign(pV);
     }
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -422,67 +422,67 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 }
 //////////////////////////////////////////////////////////////////////////
-static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {
 // A [..., m, n]
 // S [..., n]
 // U [..., m, m] or [..., m, n] if fullUV = false and m > n
 // V [..., n, n] or [..., n, m] if fullUV = false and m < n
-    auto m = A.sizeAt(-2);
-    auto n = A.sizeAt(-1);
+    auto m = A->sizeAt(-2);
+    auto n = A->sizeAt(-1);
     const int minDim = m < n ? m : n;
-    const Nd4jLong bS = A.lengthOf() / (m * n);
+    const Nd4jLong bS = A->lengthOf() / (m * n);
     if(m > 32 || n > 32)
         throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !");
-    if(minDim != S.sizeAt(-1))
+    if(minDim != S->sizeAt(-1))
         throw std::runtime_error("svdBatched: wrong shape of S array !");
     if(calcUV) {
-        if(U.sizeAt(-2) != m)
+        if(U->sizeAt(-2) != m)
           throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.sizeAt(-1) != (fullUV ? m : minDim))
+        if(U->sizeAt(-1) != (fullUV ? m : minDim))
          throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS)
+        if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS)
         throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(V.sizeAt(-2) != n)
+        if(V->sizeAt(-2) != n)
          throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.sizeAt(-1) != (fullUV ? n : minDim))
+        if(V->sizeAt(-1) != (fullUV ? n : minDim))
         throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS)
+        if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS)
        throw std::runtime_error("svdBatched: wrong shape of V array !");
     }
-    NDArray* pA = const_cast<NDArray*>(&A);
-    NDArray* pS = &S;
-    NDArray* pU = &U;
-    NDArray* pV = &V;
+    NDArray* pA = const_cast<NDArray*>(A);
+    NDArray* pS = S;
+    NDArray* pU = U;
+    NDArray* pV = V;
     std::vector<NDArray*> toDelete;
     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }
-    if(S.ews() != 1) {
-        pS = S.dup('f');
+    if(S->ews() != 1) {
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }
     if(calcUV) {
         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
            toDelete.push_back(pU);
         }
         if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
            toDelete.push_back(pV);
         }
     }
@@ -532,10 +532,10 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
-    else if(A.dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
+    if(A->dataType() == DataType::DOUBLE)
+        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
+    else if(A->dataType() == DataType::FLOAT32)
+        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -544,7 +544,7 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
     // allocate memory dWork
     void* dWork = nullptr;
-    status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdBatched: cuda failed !", status2);
     status2 = cudaDeviceSynchronize();
@@ -556,11 +556,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
+    if(A->dataType() == DataType::DOUBLE) {
+        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
+    else if(A->dataType() == DataType::FLOAT32) {
+        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -572,11 +572,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
     NDArray::registerSpecialUse({pS, pU, pV}, {pA});
-    S.assign(pS);
+    S->assign(pS);
     if(calcUV) {
-        U.assign(pU);
-        V.assign(pV);
+        U->assign(pU);
+        V->assign(pV);
     }
     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -603,8 +603,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
     NDArray* V = outArrs[2];
     if(x->rankOf() == 2) {
-        // svdQR(context, *x, *S, *U, VT, fullUV, calcUV);
-        svdJcb(context, *x, *S, *U, *V, fullUV, calcUV);
+        // svdQR(context, x, S, U, VT, fullUV, calcUV);
+        svdJcb(context, x, S, U, V, fullUV, calcUV);
     }
     else {
@@ -621,7 +621,7 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
     }
     for (int i = 0; i < tadsX->size(); ++i)
-        svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV);
+        svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV);
     delete tadsX;
     delete tadsS;
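
Switching the svd* helpers from references to pointers lets callers that only need singular values pass nullptr for U and V, as the loop above now does when calcUV is false; the calcUV ternaries inside the helpers then keep those nulls away from cuSOLVER. A hedged call sketch (the array shapes and the in-scope launch context are assumptions, not part of this commit):

    // singular values only: no U/V outputs requested
    NDArray A('f', {4, 3}, nd4j::DataType::FLOAT32);
    NDArray S('f', {3}, nd4j::DataType::FLOAT32);
    svdJcb(context, &A, &S, nullptr, nullptr, /*fullUV*/ false, /*calcUV*/ false);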


@@ -245,8 +245,8 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) {
 TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
     nd4j::ops::layer_norm op;
     auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
@@ -256,8 +256,8 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
 TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
     auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});
     nd4j::ops::layer_norm_bp op;
@@ -266,6 +266,26 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     delete result;
 }
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) {
+    NDArray x('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, nd4j::DataType::FLOAT32);
+    NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, nd4j::DataType::FLOAT32);
+    NDArray gradO('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gradI('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gradG('c', {4}, nd4j::DataType::FLOAT32);
+    NDArray gradB('c', {4}, nd4j::DataType::FLOAT32);
+    x.linspace(-20, 0.5);
+    gradO.linspace(-4, 0.05);
+    nd4j::ops::layer_norm_bp op;
+    auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true});
+    ASSERT_EQ(Status::OK(), status);
+}
 TEST_F(DeclarableOpsTests15, test_hashCode_1) {
     auto x = NDArrayFactory::create<int>('c', {10});
     auto y = NDArrayFactory::create<int>('c', {10});
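
The new Test_layer_norm_bp_2 exercises the NCHW path (B arg true, channel axis 1). A hedged sketch of the corresponding forward call in NHWC layout, assuming the same axes and the input/output execute overload used above (shapes and values are illustrative, not part of this commit):

    NDArray x('c', {3, 8, 8, 4}, nd4j::DataType::FLOAT32);                    // channels last
    NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, nd4j::DataType::FLOAT32);
    NDArray out('c', {3, 8, 8, 4}, nd4j::DataType::FLOAT32);
    x.linspace(-20, 0.5);
    nd4j::ops::layer_norm op;
    auto status = op.execute({&x, &gain}, {&out}, {}, {1,2,3}, {false});      // false -> NHWC, gain length == sizeAt(3)
    ASSERT_EQ(Status::OK(), status);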