Shyrma fix2 (#186)

* further work on layer_norm
* further work on layer_norm 2
* correct helpers for svd cuda

Signed-off-by: Yurii <yurii@skymind.io>
parent 650539528c
commit 2144941313
@@ -38,9 +38,13 @@ namespace ops {
     const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;   // INT_ARG(9): 0-NCHW, 1-NHWC
     const int dimC = isNCHW ? 1 : input->rankOf() - 1;

+    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
+
     NDArray* bias = nullptr;
-    if (block.width() > 2)
+    if (block.width() > 2) {
         bias = INPUT_VARIABLE(2);
+        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
+    }

     std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);
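In effect, the forward op now requires gain and bias to be plain rank-1 vectors whose length equals the channel extent along dimC. A minimal sketch of a conforming call, mirroring the shapes used in the new test at the bottom of this commit (the unit gain/zero bias values are illustrative, not from the commit):

    // NCHW input {3, 4, 8, 8} -> dimC = 1, so gain/bias must be rank-1 of shape {4}.
    NDArray x('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
    NDArray gain('c', {4}, {1, 1, 1, 1}, nd4j::DataType::FLOAT32);
    NDArray bias('c', {4}, {0, 0, 0, 0}, nd4j::DataType::FLOAT32);

    nd4j::ops::layer_norm op;
    // IArgs {1,2,3} are the normalization axes; the trailing BArg(true) selects NCHW.
    auto result = op.execute({&x, &gain, &bias}, {}, {1, 2, 3}, {true});
    // A gain of shape {1, 4} or {4, 1} would now trip the REQUIRE_TRUE check above.
    delete result;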
@@ -80,13 +84,16 @@ namespace ops {
     const bool isNCHW = block.getBArguments()->size() > 0 ? B_ARG(0) : true;   // INT_ARG(9): 0-NCHW, 1-NHWC
     const int dimC = isNCHW ? 1 : input->rankOf() - 1;

+    REQUIRE_TRUE(gain->rankOf() == 1 && gain->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of gain array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(gain).c_str());
+
     std::vector<int> axis = *block.getIArguments();

     std::vector<Nd4jLong> longAxis = ArrayUtils::toLongVector(axis);

     if(bias != nullptr) {
+        REQUIRE_TRUE(bias->rankOf() == 1 && bias->sizeAt(0) == input->sizeAt(dimC), 0, "LAYER_NORM_BP OP: wrong shape of bias array, expected is {%i}, but got %s instead !", input->sizeAt(dimC), ShapeUtils::shapeAsString(bias).c_str());
         // eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, {0}, true);
-        eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+        eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));
     }

     NDArray standardized(input->shapeInfo(), false, block.launchContext());
@@ -99,7 +106,7 @@ namespace ops {
     standardizeOp.execute(inputs, outputs, targs, longAxis, bargs);
     standardized.applyPairwiseTransform(nd4j::pairwise::Multiply, eps, &standardized, nullptr);
-    standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}), true);
+    standardized.reduceAlongDimension(nd4j::reduce::Sum, dLdg, ShapeUtils::evalDimsToExclude(input->rankOf(), {dimC}));

     nd4j::ops::standardize_bp standardizeBp;
     // eps->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), gain, dLdx);
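Both backprop reductions above also drop the trailing keepDims argument. A sketch of why, under the assumption that ShapeUtils::evalDimsToExclude(rank, dims) returns the complement axes (which is how the diff uses it):

    // For an NCHW input of rank 4 (dimC = 1):
    std::vector<int> axes = ShapeUtils::evalDimsToExclude(4, {1});   // -> {0, 2, 3}

    // Summing over {0, 2, 3} without keepDims collapses the result to a rank-1
    // gradient of shape {C}, matching the rank-1 gain/bias now enforced by the
    // REQUIRE_TRUE checks:
    eps->reduceAlongDimension(nd4j::reduce::Sum, dLdb, axes);        // dLdb: {C}
    // The old trailing `true` kept singleton dimensions, yielding {1, C, 1, 1}.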
@@ -100,7 +100,7 @@ static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int thr
 BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES);

 //////////////////////////////////////////////////////////////////////////
-static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) {
+static void svdQR(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* VT, const bool fullUV, const bool calcUV) {

     // since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f'
     // we make this function to have deal with 2 valid cases only:
@@ -113,59 +113,59 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     // U [m, m] or [m, n] if fullUV = false and m > n
     // VT [n, n] or [m, n] if fullUV = false and m < n

-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdQR: rank of A array is not equal 2 !");

-    auto m = A.sizeAt(0);
+    auto m = A->sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;
-    const char orderA = A.ordering();
+    const char orderA = A->ordering();

     if(m < n)
         throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");

-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdQR: wrong shape of S array !");

     if(calcUV) {

-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
             throw std::runtime_error("svdQR: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
             throw std::runtime_error("svdQR: wrong shape of U array !");

-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT))
             throw std::runtime_error("svdQR: wrong shape of VT array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT))
+        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT))
             throw std::runtime_error("svdQR: wrong shape of VT array !");
     }

-    NDArray* pA = const_cast<NDArray*>(&A);
+    NDArray* pA = const_cast<NDArray*>(A);
-    NDArray* pS = &S;
+    NDArray* pS = S;
-    NDArray* pU = &U;
+    NDArray* pU = U;
-    NDArray* pVT = &VT;
+    NDArray* pVT = VT;

     std::vector<NDArray*> toDelete;

     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }

-    if(S.ews() != 1) {
+    if(S->ews() != 1) {
-        pS = S.dup('f');
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }

     if(calcUV) {

         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }

         if(pVT->ews() != 1 || pVT->ordering() == 'c') {
-            pVT = VT.dup('f');
+            pVT = VT->dup('f');
             toDelete.push_back(pVT);
         }
     }
@@ -183,9 +183,9 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND

     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
+    if(A->dataType() == DataType::DOUBLE)
         status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
-    else if(A.dataType() == DataType::FLOAT32)
+    else if(A->dataType() == DataType::FLOAT32)
         status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
@@ -195,7 +195,7 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND

     // allocate memory for dWork
     void* dWork = nullptr;
-    cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    cudaError_t status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdQR: cuda failed !", status2);
@@ -226,11 +226,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
     NDArray::prepareSpecialUse({pS, pU, pVT}, {pA});

     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
+    if(A->dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
+        status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
+    else if(A->dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
+        status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
     }
     else
         throw std::invalid_argument("svdQR: given data type is unsupported !");
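The calcUV gating relies on cusolver's contract: when the job flags request no singular vectors ('N', set in a hunk not shown here), gesvd never dereferences the U/VT pointers, so nullptr is safe and states the intent more clearly than a dummy buffer. A sketch of the pattern with a hypothetical helper (maybeBuffer is not in the commit):

    // Forward a device buffer only when the caller asked for it; otherwise hand
    // cusolver a null pointer it will not touch.
    template <typename T>
    static T* maybeBuffer(NDArray* arr, const bool wanted) {
        return wanted ? reinterpret_cast<T*>(arr->getSpecialBuffer()) : nullptr;
    }
    // usage: cusolverDnDgesvd(..., maybeBuffer<double>(pU, calcUV), ldu,
    //                              maybeBuffer<double>(pVT, calcUV), ldvt, ...);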
@@ -242,11 +242,11 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND

     NDArray::registerSpecialUse({pS, pU, pVT}, {pA});

-    S.assign(pS);
+    S->assign(pS);

     if(calcUV) {
-        U.assign(pU);
+        U->assign(pU);
-        VT.assign(pVT);
+        VT->assign(pVT);
     }

     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -266,62 +266,62 @@ static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, ND
 }

 //////////////////////////////////////////////////////////////////////////
-static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdJcb(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {

     // A [m, n]
     // S [n]
     // U [m, m] or [m, n] if fullUV = false and m > n
     // V [n, n] or [n, m] if fullUV = false and m < n

-    if(A.rankOf() != 2)
+    if(A->rankOf() != 2)
         throw std::runtime_error("svdJcb: rank of A array is not equal 2 !");

-    auto m = A.sizeAt(0);
+    auto m = A->sizeAt(0);
-    auto n = A.sizeAt(1);
+    auto n = A->sizeAt(1);
     const int minDim = m < n ? m : n;

-    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
+    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
         throw std::runtime_error("svdJcb: wrong shape of S array !");

     if(calcUV) {

-        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
+        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
             throw std::runtime_error("svdJcb: wrong shape of U array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
+        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
             throw std::runtime_error("svdJcb: wrong shape of U array !");

-        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V))
+        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V))
             throw std::runtime_error("svdJcb: wrong shape of V array !");
-        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V))
+        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V))
             throw std::runtime_error("svdJcb: wrong shape of V array !");
     }

-    NDArray* pA = const_cast<NDArray*>(&A);
+    NDArray* pA = const_cast<NDArray*>(A);
-    NDArray* pS = &S;
+    NDArray* pS = S;
-    NDArray* pU = &U;
+    NDArray* pU = U;
-    NDArray* pV = &V;
+    NDArray* pV = V;

     std::vector<NDArray*> toDelete;

     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }

-    if(S.ews() != 1) {
+    if(S->ews() != 1) {
-        pS = S.dup('f');
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }

     if(calcUV) {

         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }

         if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
             toDelete.push_back(pV);
         }
     }
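The dup('f') blocks repeated in all three helpers exist because cusolver wants dense column-major inputs: anything c-ordered or with ews != 1 is copied into an 'f'-ordered scratch array, tracked for cleanup, and assigned back at the end. A sketch of that lifecycle, assuming NDArray::dup(char order) returns a fresh owning copy in the requested order:

    NDArray* pV = V;                                // default: operate on V in place
    if (pV->ews() != 1 || pV->ordering() == 'c') {  // not dense column-major?
        pV = V->dup('f');                           // make an 'f'-ordered scratch copy
        toDelete.push_back(pV);                     // remember to free it later
    }
    // ... run cusolver on pV ...
    V->assign(pV);                                  // copy the result back into V
    // the reverse loop over toDelete then releases every scratch array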
@@ -362,10 +362,10 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N

     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
+    if(A->dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
+        status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
-    else if(A.dataType() == DataType::FLOAT32)
+    else if(A->dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
+        status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams);
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -374,7 +374,7 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N

     // allocate memory dWork
     void* dWork = nullptr;
-    auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    auto status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdJcb: cuda failed !", status2);
@@ -383,11 +383,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});

     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
+    if(A->dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
+        status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
+    else if(A->dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
+        status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
     }
     else
         throw std::invalid_argument("svdJcb: given data type is unsupported !");
@@ -399,11 +399,11 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N

     NDArray::registerSpecialUse({pS, pU, pV}, {pA});

-    S.assign(pS);
+    S->assign(pS);

     if(calcUV) {
-        U.assign(pU);
+        U->assign(pU);
-        V.assign(pV);
+        V->assign(pV);
     }

     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -422,67 +422,67 @@ static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, N
 }

 //////////////////////////////////////////////////////////////////////////
-static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
+static void svdBatched(nd4j::LaunchContext* context, const NDArray* A, NDArray* S, NDArray* U, NDArray* V, const bool fullUV, const bool calcUV) {

     // A [..., m, n]
     // S [..., n]
     // U [..., m, m] or [..., m, n] if fullUV = false and m > n
     // V [..., n, n] or [..., n, m] if fullUV = false and m < n

-    auto m = A.sizeAt(-2);
+    auto m = A->sizeAt(-2);
-    auto n = A.sizeAt(-1);
+    auto n = A->sizeAt(-1);
     const int minDim = m < n ? m : n;
-    const Nd4jLong bS = A.lengthOf() / (m * n);
+    const Nd4jLong bS = A->lengthOf() / (m * n);

     if(m > 32 || n > 32)
         throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !");

-    if(minDim != S.sizeAt(-1))
+    if(minDim != S->sizeAt(-1))
         throw std::runtime_error("svdBatched: wrong shape of S array !");

     if(calcUV) {

-        if(U.sizeAt(-2) != m)
+        if(U->sizeAt(-2) != m)
             throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.sizeAt(-1) != (fullUV ? m : minDim))
+        if(U->sizeAt(-1) != (fullUV ? m : minDim))
             throw std::runtime_error("svdBatched: wrong shape of U array !");
-        if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS)
+        if(U->lengthOf() / (U->sizeAt(-2) * U->sizeAt(-1)) != bS)
             throw std::runtime_error("svdBatched: wrong shape of U array !");

-        if(V.sizeAt(-2) != n)
+        if(V->sizeAt(-2) != n)
             throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.sizeAt(-1) != (fullUV ? n : minDim))
+        if(V->sizeAt(-1) != (fullUV ? n : minDim))
             throw std::runtime_error("svdBatched: wrong shape of V array !");
-        if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS)
+        if(V->lengthOf() / (V->sizeAt(-2) * V->sizeAt(-1)) != bS)
             throw std::runtime_error("svdBatched: wrong shape of V array !");
     }

-    NDArray* pA = const_cast<NDArray*>(&A);
+    NDArray* pA = const_cast<NDArray*>(A);
-    NDArray* pS = &S;
+    NDArray* pS = S;
-    NDArray* pU = &U;
+    NDArray* pU = U;
-    NDArray* pV = &V;
+    NDArray* pV = V;

     std::vector<NDArray*> toDelete;

     if(pA->ews() != 1 || pA->ordering() == 'c') {
-        pA = A.dup('f');
+        pA = A->dup('f');
         toDelete.push_back(pA);
     }

-    if(S.ews() != 1) {
+    if(S->ews() != 1) {
-        pS = S.dup('f');
+        pS = S->dup('f');
         toDelete.push_back(pS);
     }

     if(calcUV) {

         if(pU->ews() != 1 || pU->ordering() == 'c') {
-            pU = U.dup('f');
+            pU = U->dup('f');
             toDelete.push_back(pU);
         }

         if(pV->ews() != 1 || pV->ordering() == 'c') {
-            pV = V.dup('f');
+            pV = V->dup('f');
             toDelete.push_back(pV);
         }
     }
@@ -532,10 +532,10 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&

     // query working space of SVD
     int lwork = 0;
-    if(A.dataType() == DataType::DOUBLE)
+    if(A->dataType() == DataType::DOUBLE)
-        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
+        status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
-    else if(A.dataType() == DataType::FLOAT32)
+    else if(A->dataType() == DataType::FLOAT32)
-        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
+        status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, &lwork, gesvdjParams, bS);
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -544,7 +544,7 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&

     // allocate memory dWork
     void* dWork = nullptr;
-    status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
+    status2 = cudaMalloc((void**)&dWork , A->sizeOfT() * lwork);
     if(status2 != cudaSuccess)
         throw cuda_exception::build("svdBatched: cuda failed !", status2);
     status2 = cudaDeviceSynchronize();
@@ -556,11 +556,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&
     NDArray::prepareSpecialUse({pS, pU, pV}, {pA});

     // choose appropriate cuda gemm api depending on data types
-    if(A.dataType() == DataType::DOUBLE) {
+    if(A->dataType() == DataType::DOUBLE) {
-        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
+        status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<double*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<double*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
-    else if(A.dataType() == DataType::FLOAT32) {
+    else if(A->dataType() == DataType::FLOAT32) {
-        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
+        status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), calcUV ? reinterpret_cast<float*>(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast<float*>(pV->getSpecialBuffer()) : nullptr, ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
     }
     else
         throw std::invalid_argument("svdBatched: given data type is unsupported !");
@@ -572,11 +572,11 @@ static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray&

     NDArray::registerSpecialUse({pS, pU, pV}, {pA});

-    S.assign(pS);
+    S->assign(pS);

     if(calcUV) {
-        U.assign(pU);
+        U->assign(pU);
-        V.assign(pV);
+        V->assign(pV);
     }

     for (int i = toDelete.size() - 1; i >= 0; --i)
@@ -603,8 +603,8 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
     NDArray* V = outArrs[2];

     if(x->rankOf() == 2) {
-        // svdQR(context, *x, *S, *U, VT, fullUV, calcUV);
+        // svdQR(context, x, S, U, VT, fullUV, calcUV);
-        svdJcb(context, *x, *S, *U, *V, fullUV, calcUV);
+        svdJcb(context, x, S, U, V, fullUV, calcUV);
     }
     else {
@@ -621,7 +621,7 @@ void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArr
     }

     for (int i = 0; i < tadsX->size(); ++i)
-        svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV);
+        svdJcb(context, tadsX->at(i), tadsS->at(i), calcUV ? tadsU->at(i) : nullptr, calcUV ? tadsV->at(i) : nullptr, fullUV, calcUV);

     delete tadsX;
     delete tadsS;
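This tad loop is where the pointer refactor pays off: with reference parameters, the calcUV == false branch had to smuggle a dummy *S in for the unused U and V arguments, whereas pointers let the call say "no output requested" directly. A sketch of the call-site difference (illustrative, condensed from the hunk above):

    // Before: references force a dummy argument even when U/V are never written.
    // svdJcb(context, *tadsX->at(i), *tadsS->at(i), *S /*dummy*/, *S /*dummy*/, fullUV, false);

    // After: nullptr is an explicit sentinel, forwarded to cusolver as shown earlier.
    svdJcb(context, tadsX->at(i), tadsS->at(i), nullptr, nullptr, fullUV, /*calcUV=*/false);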
@@ -245,8 +245,8 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) {

 TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});

     nd4j::ops::layer_norm op;
     auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
@@ -256,8 +256,8 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {

 TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
-    auto g = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
-    auto b = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.});
+    auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.});
     auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.});

     nd4j::ops::layer_norm_bp op;
@@ -266,6 +266,26 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
     delete result;
 }

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) {
+
+    NDArray x('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, nd4j::DataType::FLOAT32);
+    NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, nd4j::DataType::FLOAT32);
+    NDArray gradO('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+
+    NDArray gradI('c', {3, 4, 8, 8}, nd4j::DataType::FLOAT32);
+    NDArray gradG('c', {4}, nd4j::DataType::FLOAT32);
+    NDArray gradB('c', {4}, nd4j::DataType::FLOAT32);
+
+    x.linspace(-20, 0.5);
+    gradO.linspace(-4, 0.05);
+
+    nd4j::ops::layer_norm_bp op;
+    auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true});
+    ASSERT_EQ(Status::OK(), status);
+}

 TEST_F(DeclarableOpsTests15, test_hashCode_1) {
     auto x = NDArrayFactory::create<int>('c', {10});
     auto y = NDArrayFactory::create<int>('c', {10});
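The new Test_layer_norm_bp_2 exercises the NCHW path end to end: the trailing BArg(true) makes dimC = 1, so the {4}-shaped gain/bias line up with the 4 channels of the {3, 4, 8, 8} input, and the rank-1 gradG/gradB outputs depend on the keepDims removal above. Hypothetical shape assertions one could append to the test (not part of the commit) to pin that behavior down:

    // The rank-1 gradients must mirror gain/bias exactly, and gradI mirrors x.
    ASSERT_TRUE(gradG.isSameShape(&gain));   // {4}
    ASSERT_TRUE(gradB.isSameShape(&bias));   // {4}
    ASSERT_TRUE(gradI.isSameShape(&x));      // {3, 4, 8, 8}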