Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
AlexDBlack 2019-11-06 13:28:03 +11:00
commit 7583ccfa15
52 changed files with 3728 additions and 1509 deletions

View File

@ -17,20 +17,20 @@ endif()
# -fsanitize=address
# -fsanitize=leak
if (APPLE)
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D__APPLE_OS__=true -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
elseif(WIN32)
set(X86_BUILD true)
if (NOT CUDA_BLAS)
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -g -fPIC -std=c++11 -fmax-errors=2")
else()
set(CMAKE_CXX_FLAGS_RELEASE "-D_RELEASE=true /wd4804")
set(CMAKE_CXX_FLAGS_DEBUG " /FS /EHsc /wd4661 /wd4804 /wd4267 /wd4244 /wd4251 /wd4305")
endif()
else()
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -fPIC -std=c++11 -fmax-errors=2 -D_RELEASE=true")
set(CMAKE_CXX_FLAGS_DEBUG " -g -O0 -fPIC -std=c++11 -fmax-errors=2")
if (CPU_BLAS)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")

View File

@ -97,6 +97,8 @@ namespace nd4j {
static std::string shapeAsString(const int rank, const Nd4jLong* shapeInfo);
static std::string strideAsString(const NDArray* array);
static std::vector<Nd4jLong> shapeAsVector(const Nd4jLong* shapeInfo);
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
static Nd4jLong* evalDiagShapeInfo(const Nd4jLong* shapeInfo, nd4j::memory::Workspace* workspace);

View File

@ -39,7 +39,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
const int rows = matrix.sizeAt(0);
const int cols = matrix.sizeAt(1);
if(cols > rows) {
_transp = true;
@ -54,7 +54,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
_switchSize = switchSize;
_calcU = calcU;
_calcV = calcV;
_fullUV = fullUV;
_fullUV = fullUV;
if (_transp)
math::nd4j_swap<bool>(_calcU, _calcV);
@ -65,7 +65,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
if (_calcU)
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
else
else
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
_u.assign(0.);
@ -86,7 +86,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
const int rows = matrix.sizeAt(0);
const int cols = matrix.sizeAt(1);
if(cols > rows) {
_transp = true;
@ -101,7 +101,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
_switchSize = switchSize;
_calcU = calcU;
_calcV = calcV;
_fullUV = fullUV;
_fullUV = fullUV;
if (_transp)
math::nd4j_swap<bool>(_calcU, _calcV);
@ -112,7 +112,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
if (_calcU)
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
else
else
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
_u.assign(0.);
@ -130,13 +130,13 @@ void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
if(ind <= 0)
throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !");
int first = col1 + shift;
int first = col1 + shift;
T cos = _m.e<T>(first, first);
T sin = _m.e<T>(first+ind, first);
T denom = math::nd4j_sqrt<T, T>(cos*cos + sin*sin);
if (denom == (T)0.) {
_m.p(first+ind, first+ind, 0.f);
return;
}
@ -147,25 +147,25 @@ void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
_m.p(first,first, denom);
_m.p(first+ind, first, 0.f);
_m.p(first+ind, first+ind, 0.f);
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
rotation.p(0, 0, cos);
rotation.p(0, 1, -sin);
rotation.p(1, 0, sin);
rotation.p(1, 1, cos);
if (_calcU) {
if (_calcU) {
auto temp = _u({col1,col1+size+1, 0,0}, true);
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, temp, rotation);
}
else
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, _u, rotation);
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, _u, rotation);
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, int ind2, int size) {
if(ind1 >= ind2)
throw std::runtime_error("ops::helpers::SVD::deflation2 method: input intes must satisfy condition ind1 < ind2 !");
@ -175,9 +175,9 @@ void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, i
T cos = _m.e<T>(col1M+ind1, col1M);
T sin = _m.e<T>(col1M+ind2, col1M);
T denom = math::nd4j_sqrt<T,T>(cos*cos + sin*sin);
if (denom == (T)0.) {
_m.p(col1M + ind1, col1M + ind1, _m.e<T>(col1M + ind2, col1M + ind2));
return;
}
@ -187,21 +187,21 @@ void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, i
_m.p(col1M + ind1, col1M, denom);
_m.p(col1M + ind2, col1M + ind2, _m.e<T>(col1M + ind1, col1M + ind1));
_m.p(col1M + ind2, col1M, 0.f);
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
rotation.p(0,0, cos);
rotation.p(1,1, cos);
rotation.p(0,1, -sin);
rotation.p(1,0, sin);
if (_calcU) {
auto temp = _u({col1U,col1U+size+1, 0,0}, true);
JacobiSVD<T>::mulRotationOnRight(col1U+ind1, col1U+ind2, temp, rotation);
}
else
JacobiSVD<T>::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation);
JacobiSVD<T>::mulRotationOnRight(col1U+ind1, col1U+ind2, _u, rotation);
if (_calcV) {
auto temp = _v({row1W,row1W+size, 0,0}, true);
JacobiSVD<T>::mulRotationOnRight(col1W+ind1, col1W+ind2, temp, rotation);
@ -209,17 +209,17 @@ void SVD<T>::deflation2(int col1U , int col1M, int row1W, int col1W, int ind1, i
}
//////////////////////////////////////////////////////////////////////////
// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively
// has effect on block from (col1+shift, col1+shift) to (col2+shift, col2+shift) inclusively
template <typename T>
void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int shift)
{
const int len = col2 + 1 - col1;
auto colVec0 = new NDArray(_m({col1+shift,col1+shift+len, col1+shift,col1+shift+1}, true));
auto diagInterval = _m({col1+shift, col1+shift+len, col1+shift,col1+shift+len}, true).diagonal('c');
const T almostZero = DataTypeUtils::min<T>();
T maxElem;
if(len == 1)
@ -229,55 +229,55 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0);
T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem);
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
if(diagInterval->template e<T>(0) < epsBig)
diagInterval->p(Nd4jLong(0), epsBig);
for(int i=1; i < len; ++i)
if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps)
colVec0->p(i, 0.f);
for(int i=1; i < len; i++)
if(diagInterval->template e<T>(i) < epsBig) {
deflation1(col1, shift, i, len);
deflation1(col1, shift, i, len);
for(int i = 0; i < len; ++i)
diagInterval->p(i, _m.e<T>(col1+shift+i,col1+shift+i));
}
{
bool totDefl = true;
bool totDefl = true;
for(int i=1; i < len; i++)
if(colVec0->template e<T>(i) >= almostZero) {
totDefl = false;
break;
}
int* permut = nullptr;
int* permut = nullptr;
ALLOCATE(permut, _m.getContext()->getWorkspace(), 3*_diagSize, int);
{
permut[0] = 0;
int p = 1;
int p = 1;
for(int i=1; i<len; ++i)
if(math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero)
permut[p++] = i;
permut[p++] = i;
int k = 1, m = ind+1;
for( ; p < len; ++p) {
if(k > ind)
permut[p] = m++;
if(k > ind)
permut[p] = m++;
else if(m >= len)
permut[p] = k++;
else if(diagInterval->template e<T>(k) < diagInterval->template e<T>(m))
permut[p] = m++;
else
else
permut[p] = k++;
}
}
if(totDefl) {
for(int i=1; i<len; ++i) {
int ki = permut[i];
@ -289,17 +289,17 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
}
}
}
int *tInd = permut + len;
int *tCol = permut + 2*len;
for(int m = 0; m < len; m++) {
tCol[m] = m;
tInd[m] = m;
}
for(int i = totDefl ? 0 : 1; i < len; i++) {
const int ki = permut[len - (totDefl ? i+1 : i)];
const int jac = tCol[ki];
@ -314,31 +314,31 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
colVec0->p(jac, colVec0->template e<T>(i));
colVec0->p(i, _e0);
}
NDArray* temp1 = nullptr, *temp2 = nullptr;
if (_calcU) {
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true);
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true);
auto temp3 = temp1;
temp1.assign(temp2);
temp2.assign(temp3);
}
temp2.assign(temp3);
}
else {
auto temp1 = _u({0,2, col1+i, col1+i+1}, true);
auto temp2 = _u({0,2, col1+jac, col1+jac+1}, true);
auto temp3 = temp1;
temp1.assign(temp2);
temp2.assign(temp3);
}
temp2.assign(temp3);
}
if(_calcV) {
auto temp1 = _v({row1W,row1W+len, col1W+i, col1W+i+1}, true);
auto temp2 = _v({row1W,row1W+len, col1W+jac, col1W+jac+1}, true);
auto temp3 = temp1;
temp1.assign(temp2);
temp2.assign(temp3);
temp2.assign(temp3);
}
const int tI = tInd[i];
tCol[tI] = jac;
tCol[ki] = i;
@ -348,13 +348,13 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
RELEASE(permut, _m.getContext()->getWorkspace());
}
{
int i = len-1;
while(i > 0 && (math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero))
--i;
for(; i > 1; --i) {
if( (diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
if (math::nd4j_abs<T>(diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) >= epsBig)
@ -362,7 +362,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len);
}
}
}
}
delete colVec0;
delete diagInterval;
@ -381,7 +381,7 @@ T SVD<T>::secularEq(const T diff, const NDArray& col0, const NDArray& diag, cons
item = col0.e<T>(j) / ((diagShifted.e<T>(j) - diff) * (diag.e<T>(j) + shift + diff));
res += item * col0.e<T>(j);
}
return res;
}
@ -389,52 +389,52 @@ T SVD<T>::secularEq(const T diff, const NDArray& col0, const NDArray& diag, cons
//////////////////////////////////////////////////////////////////////////
template <typename T>
void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArray& permut, NDArray& singVals, NDArray& shifts, NDArray& mus) {
auto len = col0.lengthOf();
auto curLen = len;
while(curLen > 1 && col0.e<T>(curLen-1) == (T)0.f)
--curLen;
for (int k = 0; k < len; ++k) {
if (col0.e<T>(k) == (T)0.f || curLen==1) {
singVals.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
mus.p(k, 0.f);
shifts.p(k, k==0 ? col0.e<T>(0) : diag.e<T>(k));
continue;
}
}
T left = diag.e<T>(k);
T right;
if(k==curLen-1)
right = diag.e<T>(curLen-1) + col0.reduceNumber(reduce::Norm2).e<T>(0);
else {
int l = k+1;
while(col0.e<T>(l) == (T)0.f) {
++l;
++l;
if(l >= curLen)
throw std::runtime_error("ops::helpers::SVD::calcSingVals method: l >= curLen !");
}
right = diag.e<T>(l);
}
T mid = left + (right - left) / (T)2.;
T fMid = secularEq(mid, col0, diag, permut, diag, 0.);
T shift = (k == curLen-1 || fMid > (T)0.) ? left : right;
auto diagShifted = diag - shift;
T muPrev, muCur;
if (shift == left) {
muPrev = (right - left) * 0.1;
if (k == curLen-1)
if (k == curLen-1)
muCur = right - left;
else
else
muCur = (right - left) * 0.5;
}
else {
@ -444,67 +444,67 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
T fPrev = secularEq(muPrev, col0, diag, permut, diagShifted, shift);
T fCur = secularEq(muCur, col0, diag, permut, diagShifted, shift);
if (math::nd4j_abs<T>(fPrev) < math::nd4j_abs<T>(fCur)) {
if (math::nd4j_abs<T>(fPrev) < math::nd4j_abs<T>(fCur)) {
math::nd4j_swap<T>(fPrev, fCur);
math::nd4j_swap<T>(muPrev, muCur);
}
bool useBisection = fPrev * fCur > (T)0.;
while (fCur != (T).0 &&
math::nd4j_abs<T>(muCur - muPrev) > (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(math::nd4j_abs<T>(muCur), math::nd4j_abs<T>(muPrev))
while (fCur != (T).0 &&
math::nd4j_abs<T>(muCur - muPrev) > (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(math::nd4j_abs<T>(muCur), math::nd4j_abs<T>(muPrev))
&& math::nd4j_abs<T>(fCur - fPrev) > DataTypeUtils::eps<T>() && !useBisection) {
T a = (fCur - fPrev) / ((T)1./muCur - (T)1./muPrev);
T jac = fCur - a / muCur;
T jac = fCur - a / muCur;
T muZero = -a/jac;
T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift);
T fZero = secularEq(muZero, col0, diag, permut, diagShifted, shift);
muPrev = muCur;
fPrev = fCur;
muCur = muZero;
fCur = fZero;
if (shift == left && (muCur < (T)0. || muCur > right - left))
fCur = fZero;
if (shift == left && (muCur < (T)0. || muCur > right - left))
useBisection = true;
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
useBisection = true;
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev))
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
useBisection = true;
}
if (useBisection) {
T leftShifted, rightShifted;
if (shift == left) {
leftShifted = DataTypeUtils::min<T>();
rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6);
rightShifted = (k==curLen-1) ? right : ((right - left) * (T)0.6);
}
else {
leftShifted = -(right - left) * (T)0.6;
rightShifted = -DataTypeUtils::min<T>();
}
T fLeft = secularEq(leftShifted, col0, diag, permut, diagShifted, shift);
T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift);
T fRight = secularEq(rightShifted, col0, diag, permut, diagShifted, shift);
// if(fLeft * fRight >= (T)0.)
// throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !";
// throw "ops::helpers::SVD::calcSingVals method: fLeft * fRight >= (T)0. !";
while (rightShifted - leftShifted > (T)2.f * DataTypeUtils::eps<T>() * math::nd4j_max<T>(math::nd4j_abs<T>(leftShifted), math::nd4j_abs<T>(rightShifted))) {
T midShifted = (leftShifted + rightShifted) / (T)2.;
fMid = secularEq(midShifted, col0, diag, permut, diagShifted, shift);
if (fLeft * fMid < (T)0.)
rightShifted = midShifted;
if (fLeft * fMid < (T)0.)
rightShifted = midShifted;
else {
leftShifted = midShifted;
fLeft = fMid;
}
}
muCur = (leftShifted + rightShifted) / (T)2.;
}
}
singVals.p(k, shift + muCur);
shifts.p(k, shift);
mus.p(k, muCur);
@ -514,23 +514,23 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
//////////////////////////////////////////////////////////////////////////
template <typename T>
template <typename T>
void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& permut, const NDArray& singVals, const NDArray& shifts, const NDArray& mus, NDArray& zhat) {
int n = col0.lengthOf();
int m = permut.lengthOf();
if(m==0) {
zhat.assign(0.);
return;
}
int last = permut.e<int>(m-1);
for (int k = 0; k < n; ++k) {
if (col0.e<T>(k) == (T)0.f)
zhat.p(k, (T)0.f);
else {
else {
T dk = diag.e<T>(k);
T prod = (singVals.e<T>(last) + dk) * (mus.e<T>(last) + (shifts.e<T>(last) - dk));
@ -543,7 +543,7 @@ void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& pe
}
T tmp = math::nd4j_sqrt<T,T>(prod);
zhat.p(k, col0.e<T>(k) > (T)0.f ? tmp : -tmp);
}
}
}
}
@ -552,16 +552,16 @@ void SVD<T>::perturb(const NDArray& col0, const NDArray& diag, const NDArray& pe
template <typename T>
void SVD<T>::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArray& perm, const NDArray& singVals,
const NDArray& shifts, const NDArray& mus, NDArray& U, NDArray& V) {
int n = zhat.lengthOf();
int m = perm.lengthOf();
for (int k = 0; k < n; ++k) {
auto colU = new NDArray(U({0,0, k,k+1}, true));
*colU = 0.;
NDArray* colV = nullptr;
if (_calcV) {
colV = new NDArray(V({0,0, k,k+1}, true));
*colV = 0.;
@ -569,21 +569,21 @@ void SVD<T>::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArra
if (zhat.e<T>(k) == (T)0.f) {
colU->p(k, 1.f);
if (_calcV)
if (_calcV)
colV->p(k, 1.f);
}
else {
for(int l = 0; l < m; ++l) {
int i = perm.e<int>(l);
U.p(i,k, zhat.e<T>(i)/(((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
}
U.p(n,k, 0.f);
*colU /= colU->reduceNumber(reduce::Norm2);
if (_calcV) {
for(int l = 1; l < m; ++l){
int i = perm.e<T>(l);
V.p(i,k, diag.e<T>(i) * zhat.e<T>(i) / (((diag.e<T>(i) - shifts.e<T>(k)) - mus.e<T>(k)) )/( (diag.e<T>(i) + singVals.e<T>(k))));
@ -592,21 +592,21 @@ void SVD<T>::calcSingVecs(const NDArray& zhat, const NDArray& diag, const NDArra
*colV /= colV->reduceNumber(reduce::Norm2);
}
}
delete colU;
if (_calcV)
delete colU;
if (_calcV)
delete colV;
}
auto colU = U({0,0, n,n+1}, true);
colU = 0.;
colU.p(n, 1.);
colU.p(n, 1.);
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDArray& V) {
const T almostZero = DataTypeUtils::min<T>();
auto col0 = _m({col1, col1+size, col1, col1+1}, true);
auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c');
@ -616,30 +616,30 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
diag.p(Nd4jLong(0), T(0));
singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
U = NDArrayFactory::create<T>(_u.ordering(), {size+1, size+1}, _u.getContext());
if (_calcV)
if (_calcV)
V = NDArrayFactory::create<T>(_v.ordering(), {size, size}, _v.getContext());
int curSize = size;
while(curSize > 1 && diag.template e<T>(curSize-1) == (T)0.f)
--curSize;
int m = 0;
int m = 0;
std::vector<T> indices;
for(int k = 0; k < curSize; ++k)
if(math::nd4j_abs<T>(col0.template e<T>(k)) > almostZero)
indices.push_back((T)k);
indices.push_back((T)k);
auto permut = NDArrayFactory::create<T>(_m.ordering(), {1, (int)indices.size()}, indices, _m.getContext());
auto shifts = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
auto mus = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
auto zhat = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
calcSingVals(col0, diag, permut, singVals, shifts, mus);
perturb(col0, diag, permut, singVals, shifts, mus, zhat);
calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V);
for(int i=0; i<curSize-1; ++i) {
calcSingVecs(zhat, diag, permut, singVals, shifts, mus, U, V);
for(int i=0; i<curSize-1; ++i) {
if(singVals.e<T>(i) > singVals.e<T>(i+1)) {
T _e0 = singVals.e<T>(i);
T _e1 = singVals.e<T>(i+1);
@ -652,24 +652,24 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
auto temp3 = temp1;
temp1.assign(temp2);
temp2.assign(temp3);
if(_calcV) {
auto temp1 = V({0,0, i,i+1}, true);
auto temp2 = V({0,0, i+1,i+2}, true);
auto temp3 = temp1;
temp1.assign(temp2);
temp2.assign(temp3);
temp2.assign(temp3);
}
}
}
auto temp1 = singVals({0,curSize, 0,0}, true);
for (int e = 0; e < curSize / 2; ++e) {
T tmp = temp1.e<T>(e);
temp1.p(e, temp1.e<T>(curSize-1-e));
temp1.p(curSize-1-e, tmp);
}
}
auto temp2 = U({0,0, 0,curSize}, true);
for(int i = 0; i < curSize/2; ++i) {
auto temp3 = temp2({0,0, i,i+1}, true);
@ -678,7 +678,7 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
temp3.assign(temp4);
temp4.assign(temp5);
}
if (_calcV) {
auto temp2 = V({0,0, 0,curSize}, true);
for(int i = 0; i < curSize/2; ++i) {
@ -688,71 +688,71 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
temp3.assign(temp4);
temp4.assign(temp5);
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shift) {
// requires rows = cols + 1;
const int n = col2 - col1 + 1;
const int k = n/2;
const T almostZero = DataTypeUtils::min<T>();
T alphaK;
T betaK;
T r0;
T lambda, phi, c0, s0;
T betaK;
T r0;
T lambda, phi, c0, s0;
auto l = NDArrayFactory::create<T>(_u.ordering(), {1, k}, _u.getContext());
auto f = NDArrayFactory::create<T>(_u.ordering(), {1, n-k-1}, _u.getContext());
if(n < _switchSize) {
if(n < _switchSize) {
JacobiSVD<T> jac(_m({col1,col1+n+1, col1,col1+n}, true), _calcU, _calcV, _fullUV);
if (_calcU) {
auto temp = _u({col1,col1+n+1, col1,col1+n+1}, true);
temp.assign(jac._u);
}
else {
auto temp1 = _u({0,1, col1,col1+n+1}, true);
temp1.assign(jac._u({0,1, 0,0}, true));
temp1.assign(jac._u({0,1, 0,0}, true));
auto temp2 = _u({1,2, col1,col1+n+1}, true);
temp2.assign(jac._u({n,n+1, 0,0}, true));
}
if (_calcV) {
auto temp = _v({row1W,row1W+n, col1W,col1W+n}, true);
temp.assign(jac._v);
}
auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true);
temp.assign(0.);
temp.assign(0.);
auto diag = _m.diagonal('c');
(*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true));
delete diag;
delete diag;
return;
}
alphaK = _m.e<T>(col1 + k, col1 + k);
betaK = _m.e<T>(col1 + k + 1, col1 + k);
DivideAndConquer(k + 1 + col1, col2, k + 1 + row1W, k + 1 + col1W, shift);
DivideAndConquer(col1, k - 1 + col1, row1W, col1W + 1, shift + 1);
if (_calcU) {
lambda = _u.e<T>(col1 + k, col1 + k);
phi = _u.e<T>(col1 + k + 1, col2 + 1);
}
}
else {
lambda = _u.e<T>(1, col1 + k);
phi = _u.e<T>(0, col2 + 1);
}
r0 = math::nd4j_sqrt<T, T>((math::nd4j_abs<T>(alphaK * lambda) * math::nd4j_abs<T>(alphaK * lambda)) + math::nd4j_abs<T>(betaK * phi) * math::nd4j_abs<T>(betaK * phi));
if(_calcU) {
l.assign(_u({col1+k, col1+k+1, col1,col1+k}, true));
f.assign(_u({col1+k+1,col1+k+2, col1+k+1,col1+n}, true));
@ -766,10 +766,10 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
// VofSVD.printIndexedBuffer();
// singVals.printIndexedBuffer();
// printf("!! \n");
if (_calcV)
if (_calcV)
_v.p(row1W+k, col1W, 1.f);
if (r0 < almostZero){
c0 = 1.;
s0 = 0.;
@ -778,9 +778,9 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
c0 = alphaK * lambda / r0;
s0 = betaK * phi / r0;
}
if (_calcU) {
auto temp = _u({col1,col1+k+1, col1+k,col1+k+1}, true);
NDArray q1(temp);
@ -794,15 +794,15 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true);
temp2.assign(q1 * (-s0));
auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true);
temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0);
temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0);
auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true);
temp4 *= c0;
}
}
else {
T q1 = _u.e<T>(0, col1 + k);
for (int i = col1 + k - 1; i >= col1; --i)
for (int i = col1 + k - 1; i >= col1; --i)
_u.p(0, i+1, _u.e<T>(0, i));
_u.p(0, col1, q1 * c0);
@ -812,7 +812,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
_u({1,2, col1+1, col1+k+1}, true) = 0.f;
_u({0,1, col1+k+1, col1+n}, true) = 0.f;
}
_m.p(col1 + shift, col1 + shift, r0);
auto temp1 = _m({col1+shift+1,col1+shift+k+1, col1+shift,col1+shift+1}, true);
temp1.assign(l*alphaK);
@ -820,21 +820,21 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
temp2.assign(f*betaK);
deflation(col1, col2, k, row1W, col1W, shift);
NDArray UofSVD, VofSVD, singVals;
calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD);
calcBlockSVD(col1 + shift, n, UofSVD, singVals, VofSVD);
if(_calcU) {
auto pTemp = _u({col1, col1+n+1, col1,col1+n+1}, true);
auto temp = pTemp;
pTemp.assign(mmul(temp, UofSVD));
pTemp.assign(mmul(temp, UofSVD));
}
else {
auto pTemp = _u({0,0, col1,col1+n+1}, true);
auto temp = pTemp;
pTemp.assign(mmul(temp, UofSVD));
}
if (_calcV) {
auto pTemp = _v({row1W,row1W+n, row1W,row1W+n}, true);
auto temp = pTemp;
@ -851,29 +851,29 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
//////////////////////////////////////////////////////////////////////////
template<typename T>
void SVD<T>::exchangeUV(const HHsequence& hhU, const HHsequence& hhV, const NDArray& U, const NDArray& V) {
if (_calcU) {
int colsU = _fullUV ? hhU.rows() : _diagSize;
int colsU = _fullUV ? hhU.rows() : _diagSize;
auto temp1 = NDArrayFactory::create<T>(_u.ordering(), {hhU.rows(), colsU}, _u.getContext());
temp1.setIdentity();
_u = temp1;
_u = temp1;
auto temp2 = _u({0,_diagSize, 0,_diagSize}, true);
temp2.assign(V({0,_diagSize, 0,_diagSize}, true));
const_cast<HHsequence&>(hhU).mulLeft(_u);
}
if (_calcV) {
int colsV = _fullUV ? hhV.rows() : _diagSize;
int colsV = _fullUV ? hhV.rows() : _diagSize;
auto temp1 = NDArrayFactory::create<T>(_v.ordering(), {hhV.rows(), colsV}, _v.getContext());
temp1.setIdentity();
_v = temp1;
auto temp2 = _v({0,_diagSize, 0,_diagSize}, true);
temp2.assign(U({0,_diagSize, 0,_diagSize}, true));
const_cast<HHsequence&>(hhV).mulLeft(_v);
const_cast<HHsequence&>(hhV).mulLeft(_v);
}
}
@ -882,41 +882,37 @@ template <typename T>
void SVD<T>::evalData(const NDArray& matrix) {
const T almostZero = DataTypeUtils::min<T>();
if(matrix.sizeAt(1) < _switchSize) {
JacobiSVD<T> jac(matrix, _calcU, _calcV, _fullUV);
if(_calcU)
if(_calcU)
_u = jac._u;
if(_calcV)
if(_calcV)
_v = jac._v;
_s.assign(jac._s);
return;
}
T scale = matrix.reduceNumber(reduce::AMax).e<T>(0);
if(scale == (T)0.)
if(scale == (T)0.)
scale = 1.;
NDArray copy;
if(_transp) {
copy = NDArrayFactory::create<T>(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(0)}, matrix.getContext());
for(int i = 0; i < copy.sizeAt(0); ++i)
for(int j = 0; j < copy.sizeAt(1); ++j)
copy.p<T>(i, j, matrix.e<T>(j,i) / scale);
}
if(_transp)
copy = matrix.transpose();
else
copy = matrix / scale;
BiDiagonalUp biDiag(copy);
_u = 0.;
_v = 0.;
auto temp1 = biDiag._HHbidiag.transpose();
auto temp2 = _m({0,_diagSize, 0,0}, true);
temp2.assign(temp1);
@ -925,21 +921,21 @@ void SVD<T>::evalData(const NDArray& matrix) {
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
temp3.assign(0.);
DivideAndConquer(0, _diagSize - 1, 0, 0, 0);
DivideAndConquer(0, _diagSize - 1, 0, 0, 0);
for (int i = 0; i < _diagSize; ++i) {
T a = math::nd4j_abs<T>(_m.e<T>(i, i));
_s.p(i, a * scale);
if (a < almostZero) {
if (a < almostZero) {
auto temp = _s({i+1,_diagSize, 0,0}, true);
temp.assign(0.);
temp.assign(0.);
break;
}
else if (i == _diagSize-1)
else if (i == _diagSize-1)
break;
}
if(_transp)
if(_transp)
exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u);
else
exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v);

View File

@ -671,6 +671,20 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
return result;
}
//////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {
if(!shapeInfo)
throw std::runtime_error("ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr !");
std::vector<Nd4jLong> vector(shapeInfo[0]);
for (uint e = 0; e < shapeInfo[0]; e++)
vector[e] = shapeInfo[e + 1];
return vector;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, nd4j::memory::Workspace* workspace){

View File

@ -323,7 +323,9 @@
(11, TruncatedNormalDistribution) ,\
(12, AlphaDropOut),\
(13, ExponentialDistribution),\
(14, ExponentialDistributionInv)
(14, ExponentialDistributionInv), \
(15, PoissonDistribution), \
(16, GammaDistribution)
#define PAIRWISE_INT_OPS \
(0, ShiftLeft), \

View File

@ -58,8 +58,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
//----- calculation of output -----//
// NHWC: [kH, kW, oC, iC] x [bS, iH, iW, iC] = [kH, kW, oC, bS, iH, iW]
// NCHW: [iC, oC, kH, kW] x [bS, iC, iH, iW] = [oC, kH, kW, bS, iH, iW]
// NCHW: [kH, kW, oC, iC] x [bS, iC, iH, iW] = [kH, kW, oC, bS, iH, iW]
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
LaunchContext* ctx = block.launchContext();
helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
@ -103,8 +103,8 @@ DECLARE_SHAPE_FN(deconv2d) {
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
const int rank = 4;
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
@ -131,10 +131,10 @@ DECLARE_SHAPE_FN(deconv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
int oH, oW; // output height, width
ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
@ -196,15 +196,18 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
if(isSameMode){ // SAME
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
// ----- calculation of gradI -> pass it through conv2d_ff ----- //
nd4j::ops::conv2d conv2d;
@ -252,9 +255,9 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
const int rank = 4;
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, weightsShapeInfo[0]);
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo));
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
@ -284,10 +287,10 @@ DECLARE_SHAPE_FN(deconv2d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, oC, iC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(shape::shapeEquals(4, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

View File

@ -28,13 +28,13 @@
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
@ -52,26 +52,26 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
// create empty conv2d input array
NDArray input(gradO->ordering(), gradIShape->asVectorT<Nd4jLong>(), gradO->dataType(), block.launchContext());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}
@ -88,11 +88,10 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
auto gradIShapeShapeInfo = inputShape->at(0); // [4]
const int rank = 4;
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
REQUIRE_TRUE(gradIShapeShapeInfo[0] == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, gradIShapeShapeInfo[0]);
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV2D_TF OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
REQUIRE_TRUE(shape::rank(gradIShapeShapeInfo) == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to %i, but got %i instead !", 1, shape::rank(gradIShapeShapeInfo));
const int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) height
const int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) width
@ -109,28 +108,28 @@ DECLARE_SHAPE_FN(deconv2d_tf) {
if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1;
}
else {
else {
indIOioC = 1; indIiH = 2; indOoH = 2;
}
}
std::vector<Nd4jLong> gradIShape = INPUT_VARIABLE(0)->template asVectorT<Nd4jLong>();
const int bS = gradIShape[0]; // batch size
const int iH = gradIShape[indIiH]; // input height
const int iW = gradIShape[indIiH+1]; // input width
const int iC = gradIShape[indIOioC]; // input channels
const int iC = gradIShape[indIOioC]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
const int oH = gradOShapeInfo[indOoH+1]; // input height
const int oW = gradOShapeInfo[indOoH+2]; // input width
int trueiH, trueiW; // output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueiH, trueiW, kH, kW, sH, sW, pH, pW, dH, dW, oH, oW, isSameMode);
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradIShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,trueiH,trueiW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(expectedGradIShape == gradIShape, 0, "CUSTOM DECONV2D_TF OP: wrong shape of array with output shape, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradIShape).c_str());
REQUIRE_TRUE(shape::shapeEquals(4, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
Nd4jLong shape[4];
shape[0] = bS;

View File

@ -59,22 +59,22 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(!isNCDHW)
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
if(isSameMode) //Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
//----- calculation of output -----//
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
// NCDHW: [kD, kH, kW, oC, iC] x [bS, iC, iD, iH, iW] = [kD, kH, kW, oC, bS, iD, iH, iW]
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
@ -105,8 +105,8 @@ DECLARE_SHAPE_FN(deconv3d) {
auto biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
const int rank = 5;
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo[0]);
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, shape::rank(weightsShapeInfo));
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
@ -138,10 +138,10 @@ DECLARE_SHAPE_FN(deconv3d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
REQUIRE_TRUE(shape::rank(biasShapeInfo) <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, shape::rank(biasShapeInfo), shape::length(biasShapeInfo));
int oD, oH, oW; // output depth, height, width
ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
@ -209,15 +209,15 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
int trueoD, trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// ----- calculation of gradI -> pass it through conv3d_ff ----- //
nd4j::ops::conv3dnew conv3d;
@ -252,7 +252,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
if(!isNCDHW)
delete gradO;
return ND4J_STATUS_OK;
return Status::OK();
}
DECLARE_TYPES(deconv3d_bp) {
@ -272,9 +272,9 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
Nd4jLong* gradOShapeInfo = block.width() > 3 ? inputShape->at(3) : inputShape->at(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
const int rank = 5;
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo[0]);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, weightsShapeInfo[0]);
REQUIRE_TRUE(gradOShapeInfo[0] == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, gradOShapeInfo[0]);
REQUIRE_TRUE(shape::rank(inputShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of input array must be equal to %i, but got %i instead !", rank, shape::rank(inputShapeInfo));
REQUIRE_TRUE(shape::rank(weightsShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of weights array must be equal to %i , but got %i instead !", rank, shape::rank(weightsShapeInfo));
REQUIRE_TRUE(shape::rank(gradOShapeInfo) == rank, 0, "CUSTOM DECONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to %i, but got %i instead !", rank, shape::rank(gradOShapeInfo));
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 1));// filter(kernel) height
@ -309,10 +309,10 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
int trueoD, trueoH, trueoW; // true output depth, height, width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(shape::shapeEquals(5, expectedGradOShape.data(), shape::rank(gradOShapeInfo), shape::shapeOf(gradOShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of output gradients next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(shape::shapeEquals(5, expectedWeightsShape.data(), shape::rank(weightsShapeInfo), shape::shapeOf(weightsShapeInfo)), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

View File

@ -69,36 +69,26 @@ DECLARE_TYPES(biasadd) {
////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto bias = INPUT_VARIABLE(1);
auto epsilonNext = INPUT_VARIABLE(2);
auto epsilon = OUTPUT_VARIABLE(0);
auto input = INPUT_VARIABLE(0);
auto bias = INPUT_VARIABLE(1);
auto gradO = INPUT_VARIABLE(2);
auto gradI = OUTPUT_VARIABLE(0);
auto gradB = OUTPUT_VARIABLE(1);
epsilon->assign(epsilonNext);
const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false;
const int channelDim = isNCHW ? 1 : input->rankOf() - 1; // second or last
// cnn case
if (input->rankOf() == 4) {
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
gradI->assign(gradO);
auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
gradB->assign(sum);
delete sum;
} else if (input->rankOf() == 2) {
// regular fully-connected case
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
gradB->assign(sum);
delete sum;
}
gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
return ND4J_STATUS_OK;
}
DECLARE_SYN(BiasAddGrad, biasadd_bp);
////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(biasadd_bp) {
auto input = inputShape->at(0);
auto bias = inputShape->at(1);

View File

@ -623,7 +623,7 @@ namespace nd4j {
//Zero output array, so unused elements have 0 gradient
output->nullify();
std::sort(indices.begin(), indices.end());
if(indices.size() == 3 && (indices[1] - indices[0]) == 1) {
output->p(indices[0], *epsNext);
}

View File

@ -0,0 +1,83 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_random_gamma)
#include <ops/declarable/headers/random.h>
#include <ops/declarable/helpers/random.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(random_gamma, 2, 1, false, 0, 0) {
// gamma distribution
auto rng = block.randomGenerator();
auto shape = INPUT_VARIABLE(0);
auto alpha = INPUT_VARIABLE(1);
NDArray* beta = nullptr;
if (block.width() > 2) {
beta = INPUT_VARIABLE(2);
REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(*alpha, *beta), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
}
auto output = OUTPUT_VARIABLE(0);
auto seed = 0;
if (block.getIArguments()->size()) {
seed = INT_ARG(0);
}
rng.setSeed(seed);
helpers::fillRandomGamma(block.launchContext(), rng, alpha, beta, output);
return Status::OK();
}
DECLARE_SHAPE_FN(random_gamma) {
auto in = INPUT_VARIABLE(0);
auto shape = in->template asVectorT<Nd4jLong>();
auto alphaShape = inputShape->at(1);
auto additionalShape = alphaShape;
if (inputShape->size() > 2) {
auto rest = inputShape->at(2); additionalShape = nullptr;
REQUIRE_TRUE(ShapeUtils::areShapesBroadcastable(alphaShape, rest), 0, "random_gamma: alpha and beta shapes should be broadcastable.");
ShapeUtils::evalBroadcastShapeInfo(alphaShape, rest, true, additionalShape, block.workspace());
}
auto lastDim = shape::sizeAt(alphaShape, 0);
auto dtype = ArrayOptions::dataType(alphaShape);
for (auto i = 0; i < shape::rank(additionalShape); i++)
shape.push_back(shape::sizeAt(additionalShape, i));
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
return SHAPELIST(newShape);
}
DECLARE_TYPES(random_gamma) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif

View File

@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_random_poisson)
#include <ops/declarable/headers/random.h>
#include <ops/declarable/helpers/random.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(random_poisson, 2, 1, false, 0, 0) {
// gamma distribution
auto rng = block.randomGenerator();
auto shape = INPUT_VARIABLE(0);
auto lambda = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
auto seed = 0;
if (block.getIArguments()->size()) {
seed = INT_ARG(0);
}
rng.setSeed(seed);
helpers::fillRandomPoisson(block.launchContext(), rng, lambda, output);
return Status::OK();
}
DECLARE_SHAPE_FN(random_poisson) {
auto in = INPUT_VARIABLE(0);
auto shape = in->template asVectorT<Nd4jLong>();
auto lambdaShape = inputShape->at(1);
auto dtype = ArrayOptions::dataType(lambdaShape);
for (auto d = 0; d < shape::rank(lambdaShape); ++d ) {
shape.emplace_back(shape::sizeAt(lambdaShape, d));
}
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
return SHAPELIST(newShape);
}
DECLARE_TYPES(random_poisson) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif

View File

@ -185,42 +185,42 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),

View File

@ -24,23 +24,28 @@
#include<array>
namespace nd4j {
namespace ops {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
// first of all take into account possible presence of empty arrays
// also if scalar is present -> copy its value to vector with length=1
std::vector<NDArray*> nonEmptyArrs;
std::vector<int> arrsToDelete;
int index = 0;
bool allOfSameType = true;
auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0;
auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType();
for(int i = 0; i < block.width(); ++i) {
auto theFirstRank = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
auto theFirstDatatype = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();
for(int i = 0; i < numOfInArrs; ++i) {
auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf();
@ -50,6 +55,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
if(!input->isEmpty()) {
allOfSameType &= (theFirstDatatype == input->dataType());
if(input->rankOf() == 0) {
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
vec->assign(input);
@ -63,25 +69,28 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
}
}
const int numOfArrs = nonEmptyArrs.size();
const int numOfNonEmptyArrs = nonEmptyArrs.size();
if(numOfArrs == 0){
if(numOfNonEmptyArrs == 0){
//All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
return Status::OK();
}
const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
int axis = INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank;
int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
if(axis < 0){
axis += rank;
}
// ******** input validation ******** //
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfArrs; ++i)
for(int i = 1; i < numOfNonEmptyArrs; ++i)
REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i) {
for(int i = 1; i < numOfNonEmptyArrs; ++i) {
for(int dim = 0; dim < rank; ++dim)
if(dim != axis)
REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
@ -90,7 +99,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
auto output = OUTPUT_VARIABLE(0);
if(numOfArrs == 1)
if(numOfNonEmptyArrs == 1)
output->assign(nonEmptyArrs[0]);
else
helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);
@ -108,20 +117,25 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
DECLARE_TYPES(concat) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
->setAllowedInputTypes(nd4j::DataType::ANY);
// ->setSameMode(true);
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(concat) {
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
// first of all take into account possible presence of empty arrays
// also if scalar is present -> use the shape of vector with length=1 instead
std::vector<Nd4jLong*> arrShapes;
std::vector<int> shapesToDelete;
int index = 0;
for(int i = 0; i < block.width(); ++i) {
for(int i = 0; i < numOfInArrs; ++i) {
if(inputShape->at(i)[0] == 0) {
if (shape::isEmpty(inputShape->at(i)))
@ -135,21 +149,22 @@ DECLARE_SHAPE_FN(concat) {
++index;
}
const int numOfArrs = arrShapes.size();
const int numOfNonEmptyArrs = arrShapes.size();
const int rank = arrShapes[0][0];
int axis = INT_ARG(0);
if(axis < 0)
int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
if(axis < 0){
axis += rank;
}
// ******** input validation ******** //
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfArrs; ++i)
for(int i = 1; i < numOfNonEmptyArrs; ++i)
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i) {
for(int i = 1; i < numOfNonEmptyArrs; ++i) {
for(int dim = 0; dim < rank; ++dim)
if(dim != axis)
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
@ -161,12 +176,12 @@ DECLARE_SHAPE_FN(concat) {
COPY_SHAPE(arrShapes[0], outShapeInfo);
// case when we have only one input array
if(numOfArrs == 1) {
if(numOfNonEmptyArrs == 1) {
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
return SHAPELIST(CONSTANT(outShapeInfo));
}
for(int i = 1; i < numOfArrs; ++i)
for(int i = 1; i < numOfNonEmptyArrs; ++i)
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
@ -358,55 +373,66 @@ DECLARE_SHAPE_FN(concat) {
// return SHAPELIST(newShape);
// }
DECLARE_TYPES(concat_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);
auto first = INPUT_VARIABLE(0);
const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf());
int startPos = 0;
for (int e = 0; e < numOfInArrs - 1; e++) {
auto originalChunk = INPUT_VARIABLE(e);
auto epsilonChunk = OUTPUT_VARIABLE(e);
std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());
int width = originalChunk->sizeAt(axis);
for (int e = 0; e < epsilonNext->rankOf(); e++) {
if (e == axis)
indices[2*e + 1] = (indices[2*e] = startPos) + width;
else
indices[2*e + 1] = indices[2*e] = 0;
}
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 1) {
auto epsilonNext = INPUT_VARIABLE(block.width() - 1);
auto subarray = (*epsilonNext)(indices, true);
epsilonChunk->assign(subarray);
auto first = INPUT_VARIABLE(0);
int axis = INT_ARG(0);
if (axis < 0)
axis += first->rankOf();
int startPos = 0;
for (int e = 0; e < block.width() - 1; e++) {
auto originalChunk = INPUT_VARIABLE(e);
auto epsilonChunk = OUTPUT_VARIABLE(e);
std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());
int width = originalChunk->sizeAt(axis);
for (int e = 0; e < epsilonNext->rankOf(); e++) {
if (e == axis)
indices[2*e + 1] = (indices[2*e] = startPos) + width;
else
indices[2*e + 1] = indices[2*e] = 0;
}
auto subarray = (*epsilonNext)(indices, true);
epsilonChunk->assign(subarray);
startPos += width;
}
return ND4J_STATUS_OK;
}
DECLARE_SHAPE_FN(concat_bp) {
auto shapeList = SHAPELIST();
for (int e = 0; e < inputShape->size() - 1; e++) {
auto inShape = inputShape->at(e);
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
}
return shapeList;
}
startPos += width;
}
return ND4J_STATUS_OK;
}
DECLARE_TYPES(concat_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(concat_bp) {
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
auto shapeList = SHAPELIST();
for (int e = 0; e < numOfInArrs - 1; e++) {
auto inShape = inputShape->at(e);
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
}
return shapeList;
}
}
}

View File

@ -49,7 +49,23 @@ namespace nd4j {
DECLARE_CUSTOM_OP(random_exponential, 1, 1, true, 1, 0);
#endif
#if NOT_EXCLUDED(OP_random_crop)
DECLARE_CUSTOM_OP(random_crop, 2, 1, false, 0, 0);
#endif
/**
* random_gamma op.
*/
#if NOT_EXCLUDED(OP_random_gamma)
DECLARE_CUSTOM_OP(random_gamma, 2, 1, false, 0, 0);
#endif
/**
* random_poisson op.
*/
#if NOT_EXCLUDED(OP_random_poisson)
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
#endif
}
}

View File

@ -59,8 +59,8 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(invert_permutation, 1, 1, false, 0, 0);
#endif
DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 1);
DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 1);
DECLARE_CUSTOM_OP(concat, -1, 1, false, 0, 0);
DECLARE_CUSTOM_OP(concat_bp, -1, -1, false, 0, 0);
#if NOT_EXCLUDED(OP_mergemax)
DECLARE_OP(mergemax, -1, 1, false);

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <ops/declarable/helpers/random.h>
//#include <vector>
#include <memory>
//#include <graph/Context.h>
#include <ShapeUtils.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
Nd4jLong* broadcasted = nullptr;
if (beta != nullptr)
ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
else
broadcasted = alpha->shapeInfo();
auto step = shape::length(broadcasted);
auto shift = output->lengthOf() / step;
auto copyAlpha = alpha;
auto copyBeta = beta;
if (beta != nullptr) {
NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context);
NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context);
copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
}
// bool directAlpha = alpha->ews() == 1 && alpha->ordering() == 'c';
bool directOutput = output->ews() == 1 && output->ordering() == 'c';
T* outputBuf = output->dataBuffer()->primaryAsT<T>();
PRAGMA_OMP_PARALLEL_FOR
for (auto k = 0; k < shift; k++) {
auto pos = k * step;
auto u = rng.relativeT<T>(k, 0., 1.);
for (auto e = 0; e < step; e++)
if (directOutput) {
outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
beta != nullptr ? copyBeta->t<T>(e) * u : u);
}
else {
output->t<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
beta != nullptr ? copyBeta->t<T>(e) * u : u);
}
}
if (beta != nullptr) {
delete copyAlpha;
delete copyBeta;
//delete broadcasted;
}
}
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
}
BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
/*
* algorithm Poisson generator based upon the inversion by sequential search:[48]:505
init:
Let x 0, p eλ, s p.
Generate uniform random number u in [0,1].
while u > s do:
x x + 1.
p p * λ / x.
s s + p.
return x.
* */
template <typename T>
void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
auto shift = output->lengthOf() / lambda->lengthOf();
auto step = lambda->lengthOf();
T* lambdaBuf = lambda->dataBuffer()->primaryAsT<T>();
T* outputBuf = output->dataBuffer()->primaryAsT<T>();
bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c';
bool directOut = output->ews() == 1 && output->ordering() == 'c';
PRAGMA_OMP_PARALLEL_FOR
for (auto k = 0; k < shift; k++) {
auto pos = k * step;
auto u = rng.relativeT<T>(k, 0., 1.);
for (auto e = 0; e < step; e++) {
auto p = math::nd4j_exp<T, T>(-lambda->t<T>(e));
auto s = p;
auto x = T(0.f);
while (u > s) {
x += 1.f;
p *= directLa?lambdaBuf[e]/x:lambda->t<T>(e) / x;
s += p;
}
if (directOut)
outputBuf[pos + e] = x;
else
output->t<T>(pos + e) = x;
}
}
}
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
}
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
}
}
}

View File

@ -37,7 +37,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
const int rows = matrix.sizeAt(0);
const int cols = matrix.sizeAt(1);
if(cols > rows) {
_transp = true;
@ -52,7 +52,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
_switchSize = switchSize;
_calcU = calcU;
_calcV = calcV;
_fullUV = fullUV;
_fullUV = fullUV;
if (_transp)
math::nd4j_swap<bool>(_calcU, _calcV);
@ -63,7 +63,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
if (_calcU)
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
else
else
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
_u.assign(0.);
@ -84,7 +84,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
const int rows = matrix.sizeAt(0);
const int cols = matrix.sizeAt(1);
if(cols > rows) {
_transp = true;
@ -99,7 +99,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
_switchSize = switchSize;
_calcU = calcU;
_calcV = calcV;
_fullUV = fullUV;
_fullUV = fullUV;
if (_transp)
math::nd4j_swap<bool>(_calcU, _calcV);
@ -110,7 +110,7 @@ SVD<T>::SVD(const NDArray& matrix, const int switchSize, const bool calcU, const
if (_calcU)
_u = NDArrayFactory::create<T>(matrix.ordering(), {_diagSize + 1, _diagSize + 1}, matrix.getContext());
else
else
_u = NDArrayFactory::create<T>(matrix.ordering(), {2, _diagSize + 1}, matrix.getContext());
_u.assign(0.);
@ -128,13 +128,13 @@ void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
if(ind <= 0)
throw std::runtime_error("ops::helpers::SVD::deflation1 method: input int must satisfy condition ind > 0 !");
int first = col1 + shift;
int first = col1 + shift;
T cos = _m.e<T>(first, first);
T sin = _m.e<T>(first+ind, first);
T denom = math::nd4j_sqrt<T, T>(cos*cos + sin*sin);
if (denom == (T)0.) {
_m.p(first+ind, first+ind, 0.f);
return;
}
@ -145,14 +145,14 @@ void SVD<T>::deflation1(int col1, int shift, int ind, int size) {
_m.p(first,first, denom);
_m.p(first+ind, first, 0.f);
_m.p(first+ind, first+ind, 0.f);
auto rotation = NDArrayFactory::create<T>(_m.ordering(), {2, 2}, _m.getContext());
rotation.p(0, 0, cos);
rotation.p(0, 1, -sin);
rotation.p(1, 0, sin);
rotation.p(1, 1, cos);
if (_calcU) {
if (_calcU) {
auto temp = _u({col1,col1+size+1, 0,0}, true);
JacobiSVD<T>::mulRotationOnRight(col1, col1+ind, temp, rotation);
}
@ -466,7 +466,7 @@ void SVD<T>::calcSingVals(const NDArray& col0, const NDArray& diag, const NDArra
useBisection = true;
if (shift == right && (muCur < -(right - left) || muCur > (T)0.))
useBisection = true;
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev))
if (math::nd4j_abs<T>(fCur) > math::nd4j_abs<T>(fPrev) && math::nd4j_abs<T>(fCur - fPrev) > (T)16. * DataTypeUtils::eps<T>())
useBisection = true;
}
@ -900,12 +900,8 @@ void SVD<T>::evalData(const NDArray& matrix) {
scale = 1.;
NDArray copy;
if(_transp) {
copy = NDArrayFactory::create<T>(matrix.ordering(), {matrix.sizeAt(1), matrix.sizeAt(0)}, matrix.getContext());
for(int i = 0; i < copy.sizeAt(0); ++i)
for(int j = 0; j < copy.sizeAt(1); ++j)
copy.p<T>(i, j, matrix.e<T>(j,i) / scale);
}
if(_transp)
copy = matrix.transpose();
else
copy = matrix / scale;
@ -934,8 +930,8 @@ void SVD<T>::evalData(const NDArray& matrix) {
else if (i == _diagSize-1)
break;
}
if(_transp)
if(_transp)
exchangeUV(biDiag.makeHHsequence('v'), biDiag.makeHHsequence('u'), _v, _u);
else
exchangeUV(biDiag.makeHHsequence('u'), biDiag.makeHHsequence('v'), _u, _v);
@ -954,20 +950,20 @@ static void svd_(const NDArray* x, const std::vector<NDArray*>& outArrs, const b
auto u = outArrs[1];
auto v = outArrs[2];
const int rank = x->rankOf();
const int sRank = rank - 1;
const int rank = x->rankOf();
const int sRank = rank - 1;
auto listX = x->allTensorsAlongDimension({rank-2, rank-1});
auto listS = s->allTensorsAlongDimension({sRank-1});
ResultSet* listU(nullptr), *listV(nullptr);
if(calcUV) {
if(calcUV) {
listU = u->allTensorsAlongDimension({rank-2, rank-1});
listV = v->allTensorsAlongDimension({rank-2, rank-1});
}
for(int i = 0; i < listX->size(); ++i) {
// NDArray<T> matrix(x->ordering(), {listX->at(i)->sizeAt(0), listX->at(i)->sizeAt(1)}, block.getContext());
// matrix.assign(listX->at(i));
helpers::SVD<T> svdObj(*(listX->at(i)), switchNum, calcUV, calcUV, fullUV);
@ -976,12 +972,12 @@ static void svd_(const NDArray* x, const std::vector<NDArray*>& outArrs, const b
if(calcUV) {
listU->at(i)->assign(svdObj._u);
listV->at(i)->assign(svdObj._v);
}
}
}
delete listX;
delete listS;
if(calcUV) {
delete listU;
delete listV;

View File

@ -0,0 +1,186 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <ops/declarable/helpers/random.h>
//#include <NativeOps.h>
#include <vector>
#include <memory>
#include <graph/Context.h>
#include <helpers/RandomLauncher.h>
#include <ShapeUtils.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
namespace helpers {
/*
* fillGammaKernel - fill up output with gamma distributed values
*
* uList - uniformly distributed values set
* uLength - length of uList
* alpha - alpha param
* beta - beta param
* output - distributed output.
* */
template <typename T>
static __global__ void fillGammaKernel(T* uList, Nd4jLong uLength, T* alpha, Nd4jLong* alphaShape,
T* beta, Nd4jLong* betaShape, T* output, Nd4jLong* outputShape) {
// fill up
__shared__ Nd4jLong aLength;
if (threadIdx.x == 0) {
aLength = shape::length(alphaShape);
}
__syncthreads();
for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
auto pos = k * aLength;
auto u = uList[k]; // this is a vector
for (auto e = threadIdx.x; e < (int)aLength; e += blockDim.x) {
auto aIndex = shape::getIndexOffset(e, alphaShape);
auto bIndex = betaShape?shape::getIndexOffset(e, betaShape):-1LL;
auto betaV = T(beta != nullptr ? beta[bIndex] * u : u);
auto zIndex = shape::getIndexOffset(e + pos, outputShape);
output[zIndex] = math::nd4j_igamma<T, T, T>(alpha[aIndex], betaV);
}
}
}
template <typename T>
static void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
// To fill up output need to broadcast alpha and beta to the same shape and in
Nd4jLong* broadcasted = nullptr;
if (beta != nullptr)
ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcasted, context->getWorkspace());
else
broadcasted = alpha->shapeInfo();
auto step = shape::length(broadcasted);
auto shift = output->lengthOf() / step;
auto copyAlpha = alpha;
auto copyBeta = beta;
if (beta != nullptr) {
NDArray alphaBroadcasted(broadcasted, alpha->dataType(), true, context);
NDArray betaBroadcasted(broadcasted, beta->dataType(), true, context);
copyAlpha = (alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), alpha));
copyBeta = (betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), beta));
copyAlpha->tickWriteDevice(); copyBeta->tickWriteDevice();
}
auto stream = context->getCudaStream();
NDArray uniform = NDArrayFactory::create<T>('c', {shift}, context);
uniform.syncToDevice();
// fill up uniform with given length
RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
fillGammaKernel<T><<<128, 128, 256, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), shift,
copyAlpha->dataBuffer()->specialAsT<T>(), copyAlpha->specialShapeInfo(),
beta?copyBeta->dataBuffer()->specialAsT<T>():(T*)nullptr,
beta?copyBeta->specialShapeInfo():(Nd4jLong*)nullptr,
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
if (beta != nullptr) {
delete copyAlpha;
delete copyBeta;
//delete broadcasted;
}
}
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
if (beta)
NDArray::prepareSpecialUse({output}, {alpha, beta});
else
NDArray::prepareSpecialUse({output}, {alpha});
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
if (beta)
NDArray::registerSpecialUse({output}, {alpha, beta});
else
NDArray::prepareSpecialUse({output}, {alpha});
}
BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
/*
* algorithm Poisson generator based upon the inversion by sequential search
*
init:
Let x ← 0, p ← eλ, s ← p.
using uniformly random sequence U (u in U) distributed at [0, 1].
while u > s do:
x ← x + 1.
p ← p * λ / x.
s ← s + p.
return x.
* */
template <typename T>
static __global__ void fillPoissonKernel(T* uList, Nd4jLong uLength, T* lambda, Nd4jLong* lambdaShape, T* output,
Nd4jLong* outputShape) {
__shared__ Nd4jLong step;
if (threadIdx.x == 0) {
step = shape::length(lambdaShape);
}
__syncthreads();
for (auto k = blockIdx.x; k < (int)uLength; k += gridDim.x) {
auto pos = k * step;
auto u = uList[k];
for (auto e = threadIdx.x; e < step; e += blockDim.x) {
auto p = math::nd4j_exp<T,T>(-lambda[e]);
auto s = p;
auto x = T(0.f);
auto lIndex = shape::getIndexOffset(e, lambdaShape);
auto zIndex = shape::getIndexOffset(e + pos, outputShape);
while (u > s) {
x += T(1.);
p *= lambda[lIndex] / x;
s += p;
}
output[zIndex] = x;
}
}
}
template <typename T>
static void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
auto shift = output->lengthOf() / lambda->lengthOf();
NDArray uniform('c', {shift}, output->dataType());
auto stream = context->getCudaStream();
// fill up uniform with given length
RandomLauncher::fillUniform(context, rng, &uniform, 0., 1.);
fillPoissonKernel<T><<<128, 256, 128, *stream>>>(uniform.dataBuffer()->specialAsT<T>(), uniform.lengthOf(),
lambda->dataBuffer()->specialAsT<T>(), lambda->specialShapeInfo(),
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
}
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
NDArray::prepareSpecialUse({output}, {lambda});
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {lambda});
}
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
}
}
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
//
// Declaration of distribution helpers
//
#ifndef __RANDOM_HELPERS__
#define __RANDOM_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
#include <helpers/helper_random.h>
#include <graph/Context.h>
namespace nd4j {
namespace ops {
namespace helpers {
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
}
}
}
#endif

View File

@ -268,8 +268,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// dLdO
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc();
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem;
const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
if (dLdOReorder)
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
@ -284,8 +284,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// dLdI
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc();
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem;
const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
// gamma and beta (and their gradients) if they are present

View File

@ -29,125 +29,340 @@
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
const int isNCHW) {
namespace nd4j {
namespace ops {
namespace platforms {
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
//////////////////////////////////////////////////////////////////////
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
const int isNCHW) {
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
const_cast<NDArray *>(bias)->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
}
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
PLATFORM_IMPL(conv2d) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto output = OUTPUT_VARIABLE(
0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
return Status::OK();
}
PLATFORM_CHECK(conv2d) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
// conv2d is only available for float32 dtype
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
weights->dataType() == nd4j::DataType::FLOAT32;
}
}
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine,
const_cast<NDArray *>(bias)->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
return Status::OK();
}
PLATFORM_CHECK(conv2d) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
// conv2d is only available for float32 dtype
return block.isUseMKLDNN() && input->dataType() == nd4j::DataType::FLOAT32 &&
weights->dataType() == nd4j::DataType::FLOAT32;
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
gradO->rankOf());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
conv_weights_md, conv_dst_md, conv_strides, conv_dilation,
conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
};
return Status::OK();
}
PLATFORM_CHECK(conv2d_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}

View File

@ -1,243 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv2d_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0,
"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",
gradO->rankOf());
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
indIiH, indWiC, indWoC, indWkH, indOoH);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc =
convolution_backward_data::desc(algorithm::convolution_auto, conv_diff_src_md,
conv_weights_md, conv_dst_md, conv_strides,
conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
};
return Status::OK();
}
PLATFORM_CHECK(conv2d_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}

View File

@ -29,139 +29,373 @@
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv3dnew) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
namespace nd4j {
namespace ops {
namespace platforms {
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
if (isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
nullptr, bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(conv3dnew) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
}
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
empty);
mkldnn::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNCDHW,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, weights,
nullptr, bias, output,
&conv_src_md, nullptr, &conv_weights_md, nullptr,
&conv_bias_md, &conv_dst_md,
&user_src_md, nullptr, &user_weights_md, nullptr,
&user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = bias != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, const_cast<NDArray *>(input)->buffer());
auto user_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto conv_src_memory = user_src_memory;
if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
conv_src_memory = mkldnn::memory(conv_prim_desc.src_desc(), engine);
reorder(user_src_memory, conv_src_memory).execute(stream, user_src_memory, conv_src_memory);
}
auto conv_weights_memory = user_weights_memory;
if (conv_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
conv_weights_memory = mkldnn::memory(conv_prim_desc.weights_desc(), engine);
reorder(user_weights_memory, conv_weights_memory).execute(stream, user_weights_memory,
conv_weights_memory);
}
auto conv_dst_memory = user_dst_memory;
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
conv_dst_memory = mkldnn::memory(conv_prim_desc.dst_desc(), engine);
}
if (bias != nullptr) {
auto conv_bias_memory = mkldnn::memory(conv_prim_desc.bias_desc(), engine, bias->buffer());
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_BIAS, conv_bias_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
} else {
convolution_forward(conv_prim_desc).execute(stream, {{MKLDNN_ARG_SRC, conv_src_memory},
{MKLDNN_ARG_WEIGHTS, conv_weights_memory},
{MKLDNN_ARG_DST, conv_dst_memory}});
}
if (conv_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(conv_dst_memory, user_dst_memory).execute(stream, conv_dst_memory, user_dst_memory);
}
stream.wait();
return Status::OK();
}
PLATFORM_CHECK(conv3dnew) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(
0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, output});
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNDHWC =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNDHWC,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r, conv_dilation);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_dilation, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
}
return Status::OK();
}
PLATFORM_CHECK(conv3dnew_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}

View File

@ -1,263 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(conv3dnew_bp) {
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of input array must be equal to 5, but got %i instead !",
input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of weights array must be equal to 5, but got %i instead !",
weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0,
"CUSTOM CONV3D_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !",
gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
int isNDHWC =
block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNDHWC, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
indIOioC, indIOioD, indWiC, indWoC, indWkD);
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH,
dW, iD, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
{bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
"CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !",
expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0,
"CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !",
expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0,
"CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !",
oC, bias->rankOf(), bias->lengthOf());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
conv_diff_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty);
mkldnn::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
mkldnn::memory::dims conv_strides, conv_padding, conv_padding_r;
mkldnnUtils::getMKLDNNMemoryDescConv3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode,
isNDHWC,
bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, weights,
gradW, gradB, gradO,
&conv_src_md, &conv_diff_src_md, &conv_weights_md,
&conv_diff_weights_md, &conv_bias_md, &conv_dst_md,
&user_src_md, &user_diff_src_md, &user_weights_md,
&user_diff_weights_md, &user_bias_md, &user_dst_md,
conv_strides, conv_padding, conv_padding_r);
auto conv_desc = gradB != nullptr
? convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r)
: convolution_forward::desc(prop_kind::forward,
algorithm::convolution_auto, conv_src_md,
conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, mkldnnUtils::getEngine(
LaunchContext::defaultContext()->engine()));
if (gradW != nullptr) {
auto convW_desc = gradB != nullptr
? convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r)
: convolution_backward_weights::desc(
algorithm::convolution_auto, conv_src_md, conv_diff_weights_md,
conv_dst_md, conv_strides, conv_padding, conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convW_prim_desc = convolution_backward_weights::primitive_desc(convW_desc, engine,
conv_prim_desc);
auto userW_src_memory = mkldnn::memory(user_src_md, engine,
const_cast<NDArray *>(input)->buffer());
auto userW_weights_memory = mkldnn::memory(user_diff_weights_md, engine, gradW->buffer());
auto userW_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convW_src_memory = userW_src_memory;
if (convW_prim_desc.src_desc() != userW_src_memory.get_desc()) {
convW_src_memory = mkldnn::memory(convW_prim_desc.src_desc(), engine);
reorder(userW_src_memory, convW_src_memory).execute(stream, userW_src_memory,
convW_src_memory);
}
auto convW_weights_memory = userW_weights_memory;
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
convW_weights_memory = mkldnn::memory(convW_prim_desc.diff_weights_desc(), engine);
}
auto convW_dst_memory = userW_dst_memory;
if (convW_prim_desc.diff_dst_desc() != userW_dst_memory.get_desc()) {
convW_dst_memory = mkldnn::memory(convW_prim_desc.diff_dst_desc(), engine);
reorder(userW_dst_memory, convW_dst_memory).execute(stream, userW_dst_memory,
convW_dst_memory);
}
if (gradB != nullptr) {
auto convW_bias_memory = mkldnn::memory(convW_prim_desc.diff_bias_desc(), engine,
gradB->buffer());
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory},
{MKLDNN_ARG_DIFF_BIAS, convW_bias_memory}});
} else {
convolution_backward_weights(convW_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, convW_src_memory},
{MKLDNN_ARG_DIFF_DST, convW_dst_memory},
{MKLDNN_ARG_DIFF_WEIGHTS, convW_weights_memory}});
}
if (convW_prim_desc.diff_weights_desc() != userW_weights_memory.get_desc()) {
reorder(convW_weights_memory, userW_weights_memory).execute(stream, convW_weights_memory,
userW_weights_memory);
}
stream.wait();
}
if (gradI != nullptr) {
auto convI_desc = convolution_backward_data::desc(algorithm::convolution_auto,
conv_diff_src_md, conv_weights_md,
conv_dst_md, conv_strides, conv_padding,
conv_padding_r);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto convI_prim_desc = convolution_backward_data::primitive_desc(convI_desc, engine,
conv_prim_desc);
auto userI_src_memory = mkldnn::memory(user_diff_src_md, engine, gradI->buffer());
auto userI_weights_memory = mkldnn::memory(user_weights_md, engine,
const_cast<NDArray *>(weights)->buffer());
auto userI_dst_memory = mkldnn::memory(user_dst_md, engine,
const_cast<NDArray *>(gradO)->buffer());
auto convI_src_memory = userI_src_memory;
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
convI_src_memory = mkldnn::memory(convI_prim_desc.diff_src_desc(), engine);
}
auto convI_weights_memory = userI_weights_memory;
if (convI_prim_desc.weights_desc() != userI_weights_memory.get_desc()) {
convI_weights_memory = mkldnn::memory(convI_prim_desc.weights_desc(), engine);
reorder(userI_weights_memory, convI_weights_memory).execute(stream, userI_weights_memory,
convI_weights_memory);
}
auto convI_dst_memory = userI_dst_memory;
if (convI_prim_desc.diff_dst_desc() != userI_dst_memory.get_desc()) {
convI_dst_memory = mkldnn::memory(convI_prim_desc.diff_dst_desc(), engine);
reorder(userI_dst_memory, convI_dst_memory).execute(stream, userI_dst_memory,
convI_dst_memory);
}
convolution_backward_data(convI_prim_desc).execute(stream,
{{MKLDNN_ARG_DIFF_DST, convI_dst_memory},
{MKLDNN_ARG_WEIGHTS, convI_weights_memory},
{MKLDNN_ARG_DIFF_SRC, convI_src_memory}});
if (convI_prim_desc.diff_src_desc() != userI_src_memory.get_desc()) {
reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory,
userI_src_memory);
}
stream.wait();
}
return Status::OK();
}
PLATFORM_CHECK(conv3dnew_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
auto input = INPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(
0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(
1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB});
}
}
}
}

View File

@ -0,0 +1,535 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int isSameMode) {
// input [bS, iH, iW, iC] nchw, mkl doesn't support format nhwc
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
// bias [oC], may be nullptr
// output [bS, oH, oW, oC] nchw, mkl doesn't support format nhwc
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1};
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
// input type
mkldnn::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32;
else if(input->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16;
else if(input->dataType() == DataType::UINT8)
xType = mkldnn::memory::data_type::u8;
else
xType = mkldnn::memory::data_type::s8;
// weights type
mkldnn::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8;
// output and bias type (have the same types)
mkldnn::memory::data_type zType;
if(output->dataType() == DataType::FLOAT32)
zType = mkldnn::memory::data_type::f32;
else if(output->dataType() == DataType::HALF)
zType = mkldnn::memory::data_type::f16;
else if(output->dataType() == DataType::UINT8)
zType = mkldnn::memory::data_type::u8;
else if(output->dataType() == DataType::INT8)
zType = mkldnn::memory::data_type::s8;
else
zType = mkldnn::memory::data_type::s32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// bias
mkldnn::memory::desc b_mkl_md;
if(bias != nullptr)
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
// output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
// bias
if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
}
// output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem;
// run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
static void deconv2dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int isSameMode) {
// input and gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
// weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC]
// gradB [oC], may be nullptr
// gradO [bS, oH, oW, oC]
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
// input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
// gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
// gradB
mkldnn::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
}
// run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv2d) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode){ // SAME
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
// mkl supports only [oC, iC, kH, kW] format for weights
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
// mkl supports only NCHW
if(!isNCHW) {
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
}
deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
delete weights;
if(!isNCHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(deconv2d) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
auto output = INPUT_VARIABLE(0);
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType zType = output->dataType();
const DataType bType = bias != nullptr ? bias->dataType() : zType;
return block.isUseMKLDNN() && (
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv2d_bp) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of weights array must be equal to 4 , but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "CUSTOM DECONV2D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 4, but got %i instead !", gradO->rankOf());
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, oC, iC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode){ // SAME
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
}
// mkl supports only [oC, iC, kH, kW] for weights
weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
gradW = new NDArray(gradW->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
// mkl supports NCHW format only
if(!isNCHW) {
input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
}
deconv2dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode);
delete weights;
delete gradW;
if(!isNCHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(deconv2d_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType gradOType = gradO->dataType();
const DataType gradIType = gradI->dataType();
const DataType gradWType = gradW->dataType();
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
}
}
}
}

View File

@ -0,0 +1,244 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
// gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
// weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]
// gradO [bS, oH, oW, oC]
mkldnn::memory::dims strides = { sH, sW };
mkldnn::memory::dims dilation = { dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pH, pW };
mkldnn::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::nchw; // isNCHW ? mkldnn::memory::format_tag::nchw : mkldnn::memory::format_tag::nhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oihw;
mkldnn::memory::dims xDims = {bS, iC, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, gradOType, mkldnn::memory::format_tag::any);
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::convolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::convolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory buffers and check whether reorder is required
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
// run backward data calculations
mkldnn::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv2d_tf) {
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
int sH = INT_ARG(2); // strides height
int sW = INT_ARG(3); // strides width
int pH = INT_ARG(4); // paddings height
int pW = INT_ARG(5); // paddings width
int dH = INT_ARG(6); // dilations height
int dW = INT_ARG(7); // dilations width
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 1-NHWC, 0-NCHW
const int rank = gradO->rankOf();
REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
int indIOioC, indIiH, indWoC(3), indOoH;
if(!isNCHW) {
indIOioC = 3; indIiH = 1; indOoH = 1;
}
else {
indIOioC = 1; indIiH = 2; indOoH = 2;
}
std::vector<Nd4jLong> gradIShapeVector = gradIShape->template asVectorT<Nd4jLong>();
const int bS = gradIShapeVector[0]; // batch size
const int iH = gradIShapeVector[indIiH]; // input height
const int iW = gradIShapeVector[indIiH+1]; // input width
const int iC = gradIShapeVector[indIOioC]; // input channels
const int oC = weights->sizeAt(indWoC); // output channels
const int oH = gradO->sizeAt(indOoH); // input height
const int oW = gradO->sizeAt(indOoH); // input width
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// mkl supports only [oC, iC, kH, kW] for weights
weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
// mkl supports NCHW format only
if(!isNCHW) {
gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
}
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW);
delete weights;
if(!isNCHW) {
delete gradI;
delete gradO;
}
// ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}
PLATFORM_CHECK(deconv2d_tf) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCDHW), gradI
const DataType wType = weights->dataType();
const DataType gradOType = gradO->dataType();
const DataType gradIType = gradI->dataType();
return block.isUseMKLDNN() && ((wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16));
}
}
}
}

View File

@ -0,0 +1,549 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
const int isSameMode) {
// input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
// bias [oC], may be nullptr
// output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1};
mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
// input type
mkldnn::memory::data_type xType;
if(input->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32;
else if(input->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16;
else if(input->dataType() == DataType::UINT8)
xType = mkldnn::memory::data_type::u8;
else
xType = mkldnn::memory::data_type::s8;
// weights type
mkldnn::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8;
// output and bias type (have the same types)
mkldnn::memory::data_type zType;
if(output->dataType() == DataType::FLOAT32)
zType = mkldnn::memory::data_type::f32;
else if(output->dataType() == DataType::HALF)
zType = mkldnn::memory::data_type::f16;
else if(output->dataType() == DataType::UINT8)
zType = mkldnn::memory::data_type::u8;
else if(output->dataType() == DataType::INT8)
zType = mkldnn::memory::data_type::s8;
else
zType = mkldnn::memory::data_type::s32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// bias
mkldnn::memory::desc b_mkl_md;
if(bias != nullptr)
b_mkl_md = mkldnn::memory::desc({oC}, zType, mkldnn::memory::format_tag::x);
// output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(zDims, zType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(zDims, zType, xFormat);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4];
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// operation primitive description
mkldnn::deconvolution_forward::desc op_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct,
x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_prim_desc(op_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
// bias
if(bias != nullptr) {
auto b_mkl_mem = mkldnn::memory(b_mkl_md, engine, bias->getBuffer());
args[MKLDNN_ARG_BIAS] = b_mkl_mem;
}
// output
auto z_user_mem = mkldnn::memory(z_user_md, engine, output->getBuffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[MKLDNN_ARG_DST] = z_mkl_mem;
// run calculations
mkldnn::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
const int isSameMode) {
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
// gradB [oC], may be nullptr
// gradO [bS, oD, oH, oW, oC]
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
mkldnn::memory::dims strides = { sD, sH, sW };
mkldnn::memory::dims dilation = { dD - 1, dH - 1, dW - 1 };
mkldnn::memory::dims padding = { pD, pH, pW };
mkldnn::memory::dims padding_r = {(iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
// input type
mkldnn::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// weights type
mkldnn::memory::data_type wType = weights->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradO type
mkldnn::memory::data_type gradOType = gradO->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradI type
mkldnn::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradW type
mkldnn::memory::data_type gradWType = gradW->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16;
// gradB type
mkldnn::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? mkldnn::memory::data_type::f32 : mkldnn::memory::data_type::bf16) : mkldnn::memory::data_type::f32;
mkldnn::memory::format_tag xFormat = mkldnn::memory::format_tag::ncdhw; // isNCDHW ? mkldnn::memory::format_tag::ncdhw : mkldnn::memory::format_tag::ndhwc;
mkldnn::memory::format_tag wFormat = mkldnn::memory::format_tag::oidhw;
mkldnn::memory::dims xDims = {bS, iC, iD, iH, iW};
mkldnn::memory::dims wDims = {oC, iC, kD, kH, kW};
mkldnn::memory::dims zDims = {bS, oC, oD, oH, oW};
// memory descriptors for arrays
// input
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(xDims, xType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(xDims, xType, xFormat);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
// weights
mkldnn::memory::desc w_mkl_md = mkldnn::memory::desc(wDims, wType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc w_user_md = mkldnn::memory::desc(wDims, wType, wFormat);
w_user_md.data.format_kind = mkldnn_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
// gradO
mkldnn::memory::desc gradO_mkl_md = mkldnn::memory::desc(zDims, gradOType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradO_user_md = mkldnn::memory::desc(zDims, gradOType, xFormat);
gradO_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
// gradI
mkldnn::memory::desc gradI_mkl_md = mkldnn::memory::desc(xDims, gradIType, mkldnn::memory::format_tag::any);
mkldnn::memory::desc gradI_user_md = mkldnn::memory::desc(xDims, gradIType, xFormat);
gradI_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
// gradW
mkldnn::memory::desc gradW_mkl_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
mkldnn::memory::desc gradW_user_md = mkldnn::memory::desc(wDims, gradWType, wFormat);
gradW_user_md.data.format_kind = mkldnn_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
// gradB
mkldnn::memory::desc gradB_mkl_md;
if(gradB != nullptr)
gradB_mkl_md = mkldnn::memory::desc({oC}, gradBType, mkldnn::memory::format_tag::x);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward primitive description
mkldnn::deconvolution_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, mkldnn::algorithm::deconvolution_direct, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// backward data primitive description
mkldnn::deconvolution_backward_data::desc op_data_bp_desc(mkldnn::algorithm::deconvolution_direct, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
// backward weights primitive description
mkldnn::deconvolution_backward_weights::desc op_weights_bp_desc(mkldnn::algorithm::deconvolution_direct, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
mkldnn::deconvolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = mkldnn::memory(x_user_md, engine, input->getBuffer());
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// weights
auto w_user_mem = mkldnn::memory(w_user_md, engine, weights->getBuffer());
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
auto w_mkl_mem = wReorder ? mkldnn::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
if (wReorder)
mkldnn::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
// gradO
auto gradO_user_mem = mkldnn::memory(gradO_user_md, engine, gradO->getBuffer());
const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
auto gradO_mkl_mem = gradOReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
if (gradOReorder)
mkldnn::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = gradO_mkl_mem;
// gradI
auto gradI_user_mem = mkldnn::memory(gradI_user_md, engine, gradI->getBuffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? mkldnn::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = gradI_mkl_mem;
// gradW
auto gradW_user_mem = mkldnn::memory(gradW_user_md, engine, gradW->getBuffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? mkldnn::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[MKLDNN_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
// gradB
if(gradB != nullptr) {
auto gradB_mkl_mem = mkldnn::memory(gradB_mkl_md, engine, gradB->getBuffer());
args[MKLDNN_ARG_DIFF_BIAS] = gradB_mkl_mem;
}
// run backward data calculations
mkldnn::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// run backward weights calculations
mkldnn::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
mkldnn::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
mkldnn::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv3d) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode){ // SAME
//Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
}
// mkl supports only [oC, iC, kD, kH, kW] format for weights
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
// mkl supports only NCDHW
if(!isNCDHW) {
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
}
deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
delete weights;
if(!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
PLATFORM_CHECK(deconv3d) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
auto input = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
auto output = INPUT_VARIABLE(0);
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType zType = output->dataType();
const DataType bType = bias != nullptr ? bias->dataType() : zType;
return block.isUseMKLDNN() && (
(xType==DataType::FLOAT32 && wType==DataType::FLOAT32 && bType==DataType::FLOAT32 && zType==DataType::FLOAT32) ||
(xType==DataType::HALF && wType==DataType::HALF && bType==DataType::HALF && zType==DataType::HALF ) ||
((xType==DataType::UINT8 || xType==DataType::INT8) && wType==DataType::INT8 && (zType==DataType::UINT8 || zType==DataType::INT8 || zType==DataType::INT32 || zType==DataType::FLOAT32) && bType == zType)
);
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(deconv3d_bp) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be equal to 5 , but got %i instead !", weights->rankOf());
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 0-SAME, 1-VALID
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
int trueoD, trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizeDeconv3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// mkl supports only [oC, iC, kD, kH, kW] for weights
weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
gradW = new NDArray(gradW->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
// mkl supports NCDHW format only
if(!isNCDHW) {
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
}
deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isSameMode);
delete weights;
delete gradW;
if(!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
return Status::OK();
}
PLATFORM_CHECK(deconv3d_bp) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iD, iC, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
const DataType xType = input->dataType();
const DataType wType = weights->dataType();
const DataType gradOType = gradO->dataType();
const DataType gradIType = gradI->dataType();
const DataType gradWType = gradW->dataType();
const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;
return block.isUseMKLDNN() && ((xType==DataType::FLOAT32 || xType==DataType::BFLOAT16) && (wType==DataType::FLOAT32 || wType==DataType::BFLOAT16) && (gradOType==DataType::FLOAT32 || gradOType==DataType::BFLOAT16) && (gradIType==DataType::FLOAT32 || gradIType==DataType::BFLOAT16) && (gradWType==DataType::FLOAT32 || gradWType==DataType::BFLOAT16) && (gradBType==DataType::FLOAT32 || gradBType==DataType::BFLOAT16) );
}
}
}
}

View File

@ -416,36 +416,36 @@ PLATFORM_IMPL(lstmLayer) {
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip)};

View File

@ -148,14 +148,15 @@ namespace nd4j {
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oH, oW };
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_strides = { sH, sW };
conv_padding = { pH, pW };
conv_dilation = { dH-1, dW-1};
conv_padding_r = { (oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };
@ -227,14 +228,15 @@ namespace nd4j {
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r) {
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation) {
mkldnn::memory::dims conv_src_tz = { bS, iC, iD, iH, iW };
mkldnn::memory::dims conv_weights_tz = { oC, iC, kD, kH, kW };
mkldnn::memory::dims conv_bias_tz = { oC };
mkldnn::memory::dims conv_dst_tz = { bS, oC, oD, oH, oW };
conv_strides = { sD, sH, sW };
conv_padding = { pD, pH, pW };
conv_strides = { sD, sH, sW };
conv_dilation = { dD-1, dH-1, dW-1};
conv_padding = { pD, pH, pW };
conv_padding_r = { (oD - 1) * sD - iD + kD - pD,
(oH - 1) * sH - iH + kH - pH,
(oW - 1) * sW - iW + kW - pW };

View File

@ -67,6 +67,16 @@ namespace nd4j{
DECLARE_PLATFORM(batchnorm_bp);
DECLARE_PLATFORM(lstmLayer);
DECLARE_PLATFORM(deconv2d);
DECLARE_PLATFORM(deconv2d_tf);
DECLARE_PLATFORM(deconv3d);
DECLARE_PLATFORM(deconv2d_bp);
DECLARE_PLATFORM(deconv3d_bp);
}
}
@ -83,7 +93,7 @@ namespace nd4j{
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
void getMKLDNNMemoryDescConv3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
@ -93,7 +103,7 @@ namespace nd4j{
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r, mkldnn::memory::dims& conv_dilation);
void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,

View File

@ -129,6 +129,47 @@ namespace randomOps {
}
};
template <typename T>
class PoissonDistribution {
public:
no_exec_special
no_exec_special_cuda
method_XY
random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
T x = helper->relativeT(idx, -nd4j::DataTypeUtils::template max<T>() / 10 , nd4j::DataTypeUtils::template max<T>() / 10);
return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igammac<T,T,T>(nd4j::math::nd4j_floor<T,T>(x), lambda);
}
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
T lambda = extraParams[0];
return valueX <= (T)0.f ? (T)0.f : (T)nd4j::math::nd4j_igammac<T,T,T>(nd4j::math::nd4j_floor<T,T>(valueX), lambda);
}
};
template <typename T>
class GammaDistribution {
public:
no_exec_special
no_exec_special_cuda
method_XY
random_def T op(Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
T alpha = extraParams[0];
T beta = extraParams[1];
T x = helper->relativeT(idx, -nd4j::DataTypeUtils::template max<T>() / 10 , nd4j::DataTypeUtils::template max<T>() / 10);
return x <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T,T,T>(alpha, x * beta);
}
random_def T op(T valueX, Nd4jLong idx, Nd4jLong length, nd4j::graph::RandomGenerator *helper, T *extraParams) {
T alpha = extraParams[0];
T beta = extraParams[1];
return valueX <= (T)0.f ? (T)0.f : nd4j::math::nd4j_igamma<T,T,T>(alpha, beta * valueX);
}
};
/**
* Basic DropOut/DropConnect Op

View File

@ -894,6 +894,10 @@ namespace nd4j {
Z aim = nd4j_pow<X, X, Z>(x, a) / (nd4j_exp<X, Z>(x) * nd4j_gamma<Y, Z>(a));
auto sum = Z(0.);
auto denom = Z(1.);
if (a <= X(0.000001))
//throw std::runtime_error("Cannot calculate gamma for a zero val.");
return Z(0);
for (int i = 0; Z(1./denom) > Z(1.0e-12); i++) {
denom *= (a + i);
sum += nd4j_pow<X, int, Z>(x, i) / denom;

View File

@ -30,7 +30,7 @@ endif()
if (CMAKE_BUILD_TYPE STREQUAL "Release")
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fmax-errors=2")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
else()
@ -38,13 +38,13 @@ if (CMAKE_BUILD_TYPE STREQUAL "Release")
endif()
else()
if (APPLE)
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2 -D__APPLE_OS__=true")
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
elseif(WIN32)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(CMAKE_CXX_FLAGS " -O0 -g --fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
set(CMAKE_CXX_FLAGS " -O0 -g --fPIC -std=c++11 -fmax-errors=2")
endif()
else()
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -fmax-errors=2")
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fmax-errors=2")
if (CPU_BLAS)
SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
endif()

View File

@ -437,58 +437,38 @@ TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
}
TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) {
Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
TypeParam _expB[] = {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0,};
NDArray exp(_expB, _expS);
auto input = NDArrayFactory::create_<TypeParam>('c', {2, 3, 4, 4});
auto weights = NDArrayFactory::create_<TypeParam>('c', {3, 3, 5, 5});
int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=8,oW=8;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NHWC, 0-NCHW
input->linspace(1);
weights->linspace(1);
weights->permutei({2,3,1,0});
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203.,
4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206.,
7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209.,
10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212.,
13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215.,
16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218.,
19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221.,
22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224.,
25., 100., 175.,50., 125., 200.,75., 150., 225.});
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, input);
variableSpace->putVariable(-2, weights);
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0});
auto block = new Context(1, variableSpace, false);
block->fillInputs({-1, -2});
block->getIArguments()->push_back(5);
block->getIArguments()->push_back(5);
block->getIArguments()->push_back(1);
block->getIArguments()->push_back(1);
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
// dilation
block->getIArguments()->push_back(1);
block->getIArguments()->push_back(1);
// NOT same mode
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
input.linspace(1);
nd4j::ops::deconv2d op;
auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
Nd4jStatus status = op.execute(block);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_EQ(ND4J_STATUS_OK, status);
auto output = variableSpace->getVariable(1)->getNDArray();
auto output = results->at(0);
ASSERT_TRUE(exp.isSameShape(output));
// exp.printBuffer("Expctd buffer");
//output->printBuffer("Result buffer");
ASSERT_TRUE(exp.equalsTo(output));
delete variableSpace;
delete block;
delete results;
}
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
@ -812,61 +792,54 @@ TEST_F(ConvolutionTests1, Test_im2col_col2im_3) {
TEST_F(ConvolutionTests1, TestDeconv_bp_1) {
int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=4,oW=4;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NHWC, 0-NCHW
double _expb[] = { 35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f};
std::shared_ptr<DataBuffer> pBuffer1 = std::make_shared<DataBuffer>(_expb, sizeof(_expb), nd4j::DataType::DOUBLE, false);
NDArray expEpsilon(pBuffer1, 'c', {3, 3, 4, 4});
double _expwb[] = { 160008.f, 203400.f, 191112.f, 246792.f, 222216.f, 290184.f};
std::shared_ptr<DataBuffer> pBuffer2 = std::make_shared<DataBuffer>(_expwb, sizeof(_expwb), nd4j::DataType::DOUBLE, false);
NDArray expGradW(pBuffer2, 'c', {3, 2, 1, 1});
expGradW.permutei({2,3,1,0});
NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);
NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, nd4j::DataType::FLOAT32);
NDArray gradO('c', {bS, oC, oH, oW},nd4j::DataType::FLOAT32);
double _expbb[] = {1944.f, 2712.f};
std::shared_ptr<DataBuffer> pBuffer3 = std::make_shared<DataBuffer>(_expbb, sizeof(_expbb), nd4j::DataType::DOUBLE, false);
NDArray expGradB(pBuffer3, 'c', {1, 2});
auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
auto bias = NDArrayFactory::create<double>('c', {1, 2});
auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
auto epsilon = NDArrayFactory::create<double>('c', {3, 2, 4, 4});
/*
Input shape (3, 3, 4, 4)
Weights shape (3, 2, 1, 1)
Epsilon shape (3, 2, 4, 4)
*/
NDArray expGradI('c', {bS, iC, iH, iW}, {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f,
77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f,
176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f,
131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f,
302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f,
481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f,
236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f,
547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f,
866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kH, kW, oC, iC}, {160008., 191112., 222216., 203400., 246792., 290184.f}, nd4j::DataType::FLOAT32);
NDArray expGradB('c', {oC}, {1944.f, 2712.f}, nd4j::DataType::FLOAT32);
input.linspace(1);
weights.linspace(1);
bias.linspace(1);
epsilon.linspace(1);
weights.permutei({2,3,1,0});
gradO.linspace(1);
nd4j::ops::deconv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto gradI = results->at(0);
auto gradW = results->at(1);
auto gradB = results->at(2);
auto expNext = result->at(0);
ASSERT_TRUE(expEpsilon.isSameShape(expNext));
ASSERT_TRUE(expEpsilon.equalsTo(expNext));
auto gradW = result->at(1);
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
auto gradB = result->at(2);
ASSERT_TRUE(expGradB.isSameShape(gradB));
ASSERT_TRUE(expGradB.equalsTo(gradB));
delete result;
delete results;
}
TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
/*
Input shape:
@ -914,13 +887,11 @@ TEST_F(ConvolutionTests1, TestDeconv_ff_2) {
NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.});
auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
auto weights = NDArrayFactory::create<double>('c',{1, 1, 2, 3}, {1,3,5,2,4,6});
auto bias = NDArrayFactory::create<double>('c', {2});
input.linspace(1);
weights.linspace(1);
bias.linspace(1);
weights.permutei({2,3,1,0});
nd4j::ops::deconv2d op;
@ -2337,14 +2308,14 @@ TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test1) {
int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=3,oW=3;
int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int iH=3,iW=3;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, oC, iC});
auto exp = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,

View File

@ -575,24 +575,38 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});
auto bias = NDArrayFactory::create<double>('c', {iC});
auto gradO = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC});
auto input = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
auto bias = NDArrayFactory::create<float>('c', {iC});
auto gradO = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32);
NDArray expGradB('c', {iC}, {364.5}, nd4j::DataType::FLOAT32);
input = 0.5;
weights.linspace(0.1, 0.1);
gradO.linspace(0.5);
const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
nd4j::ops::deconv3d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
nd4j::ops::deconv3d opFF;
nd4j::ops::deconv3d_bp opBP;
auto gradI = results->at(0);
auto gradW = results->at(1);
auto gradB = results->at(2);
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(isGradCorrect);
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
ASSERT_TRUE(expGradB.isSameShape(gradB));
ASSERT_TRUE(expGradB.equalsTo(gradB));
delete results;
}
//////////////////////////////////////////////////////////////////////
@ -603,23 +617,32 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test2) {
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});
auto gradO = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC});
auto input = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
auto gradO = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, nd4j::DataType::FLOAT32);
input = 0.5;
weights.linspace(0.1, 0.1);
gradO.linspace(0.5);
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
nd4j::ops::deconv3d_bp op;
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
nd4j::ops::deconv3d opFF;
nd4j::ops::deconv3d_bp opBP;
auto gradI = results->at(0);
auto gradW = results->at(1);
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(isGradCorrect);
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
delete results;
}
//////////////////////////////////////////////////////////////////////
@ -630,24 +653,31 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test3) {
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});
auto gradO = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW});
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, nd4j::DataType::FLOAT32);
input = 0.5;
weights.linspace(0.1, 0.1);
gradO.linspace(0.5);
weights.permutei({2, 3, 4, 1, 0});
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
nd4j::ops::deconv3d_bp op;
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
nd4j::ops::deconv3d opFF;
nd4j::ops::deconv3d_bp opBP;
auto gradI = results->at(0);
auto gradW = results->at(1);
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(isGradCorrect);
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
delete results;
}
//////////////////////////////////////////////////////////////////////
@ -658,24 +688,31 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) {
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});
auto gradO = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW});
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, nd4j::DataType::FLOAT32);
input = 0.5;
weights.linspace(0.1, 0.1);
gradO.linspace(0.5);
weights.permutei({2, 3, 4, 1, 0});
const OpArgsHolder argsHolderFF({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
const OpArgsHolder argsHolderBP({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
nd4j::ops::deconv3d_bp op;
auto results = op.execute({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
nd4j::ops::deconv3d opFF;
nd4j::ops::deconv3d_bp opBP;
auto gradI = results->at(0);
auto gradW = results->at(1);
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(isGradCorrect);
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
delete results;
}
//////////////////////////////////////////////////////////////////////

View File

@ -37,21 +37,6 @@ public:
}
};
TEST_F(DeclarableOpsTests11, test_mixed_biasadd_1) {
if (!Environment::getInstance()->isExperimentalBuild())
return;
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto y = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
auto z = NDArrayFactory::create<float>('c', {2, 3});
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f});
nd4j::ops::biasadd op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {true});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp, z);
}
TEST_F(DeclarableOpsTests11, test_listdiff_1) {
auto x = NDArrayFactory::create<int>('c', {4}, {0, 1, 2, 3});

View File

@ -243,7 +243,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_12) {
}
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NHWC_1) {
TEST_F(DeclarableOpsTests4, biasadd_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f});
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests4, Test_BiasAdd_NHWC_1) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_BiasAdd_NCHW_1) {
TEST_F(DeclarableOpsTests4, biasadd_2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 3});
auto bias = NDArrayFactory::create<double>('c', {2}, {1, 2});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2});
@ -279,6 +279,95 @@ TEST_F(DeclarableOpsTests4, Test_BiasAdd_NCHW_1) {
delete result;
}
TEST_F(DeclarableOpsTests4, biasadd_3) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
nd4j::ops::biasadd op;
auto result = op.execute({&x, &row}, {}, {}, {true}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, biasadd_bp_1) {
NDArray x('c', {2,2,2,3}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
NDArray gradO('c', {2,2,2,3}, nd4j::DataType::FLOAT32);
NDArray bias('c', {3}, {-1., -2, -3}, nd4j::DataType::FLOAT32);
NDArray expGradB('c', {3}, {9.2, 10. , 10.8}, nd4j::DataType::FLOAT32);
gradO.linspace(0.1, 0.1);
nd4j::ops::biasadd_bp op;
auto result = op.execute({&x, &bias, &gradO}, {}, {}, {false}); // NHWC
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto gradI = result->at(0);
auto gradB = result->at(1);
ASSERT_TRUE(gradI->isSameShape(gradO));
ASSERT_TRUE(gradI->equalsTo(gradO));
ASSERT_TRUE(gradB->isSameShape(expGradB));
ASSERT_TRUE(gradB->equalsTo(expGradB));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, biasadd_bp_2) {
NDArray x('c', {2,3,2,2}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
NDArray gradO('c', {2,3,2,2}, nd4j::DataType::FLOAT32);
NDArray bias('c', {3}, {-1., -2, -3}, nd4j::DataType::FLOAT32);
NDArray expGradB('c', {3}, {6.8, 10., 13.2}, nd4j::DataType::FLOAT32);
gradO.linspace(0.1, 0.1);
nd4j::ops::biasadd_bp op;
auto result = op.execute({&x, &bias, &gradO}, {}, {}, {true}); // NCHW
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto gradI = result->at(0);
auto gradB = result->at(1);
ASSERT_TRUE(gradI->isSameShape(gradO));
ASSERT_TRUE(gradI->equalsTo(gradO));
ASSERT_TRUE(gradB->isSameShape(expGradB));
ASSERT_TRUE(gradB->equalsTo(expGradB));
delete result;
}
TEST_F(DeclarableOpsTests4, biasadd_4) {
if (!Environment::getInstance()->isExperimentalBuild())
return;
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto y = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
auto z = NDArrayFactory::create<float>('c', {2, 3});
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f});
nd4j::ops::biasadd op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {true});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp, z);
}
TEST_F(DeclarableOpsTests4, Test_Fill_1) {
auto x = NDArrayFactory::create<int>('c', {1, 3}, {3, 2, 4});
auto v = NDArrayFactory::create<double>(2.);
@ -639,24 +728,6 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_BiasAdd_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
nd4j::ops::biasadd op;
auto result = op.execute({&x, &row}, {}, {}, {true}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_1D_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3});

View File

@ -241,6 +241,52 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
int zero = 0;
auto matrix = NDArrayFactory::create<double>('c', {5, 4});
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto grad = NDArrayFactory::create<double>('c', {5,4});
matrix.linspace(1);
grad.linspace(1);
nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
//ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
int zero = 0;
auto matrix = NDArrayFactory::create<double>('c', {1, 2});
// auto b = NDArrayFactory::create<int>('c', {1}, {zero});
// auto e = NDArrayFactory::create<int>('c', {1}, {zero});
// auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto grad = NDArrayFactory::create<double>('c', {1}, {1.});
matrix.linspace(1);
//grad.linspace(1);
nd4j::ops::strided_slice_bp op;
auto result = op.execute({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
//ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});

View File

@ -756,6 +756,27 @@ TEST_F(DeclarableOpsTests9, concat_test24) {
ASSERT_EQ(e, z);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test25) {
auto x0 = NDArrayFactory::create<double>('c', {1,4}, {1,2,3,4});
auto x1 = NDArrayFactory::create<double>('c', {1,4}, {5,6,7,8});
auto axis = NDArrayFactory::create<double>('c', {1}, {0.});
auto exp = NDArrayFactory::create<double>('c', {2,4}, {1,2,3,4,5,6,7,8});
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &axis}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

View File

@ -773,6 +773,88 @@ TEST_F(RNGTests, Test_ExponentialDistribution_2) {
delete result;
}
TEST_F(RNGTests, Test_PoissonDistribution_1) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
auto la = NDArrayFactory::create<float>('c', {2, 3});
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
la.linspace(1.0);
nd4j::ops::random_poisson op;
auto result = op.execute({&x, &la}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Poisson distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_1) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
auto al = NDArrayFactory::create<float>('c', {2, 3});
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
al.linspace(1.0);
nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_2) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
auto al = NDArrayFactory::create<float>('c', {2, 3});
auto be = NDArrayFactory::create<float>('c', {2, 3});
auto exp0 = NDArrayFactory::create<float>('c', {10, 2, 3});
al.linspace(1.0);
be.assign(1.0);
nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
TEST_F(RNGTests, Test_GammaDistribution_3) {
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
auto al = NDArrayFactory::create<float>('c', {3, 1});
auto be = NDArrayFactory::create<float>('c', {1, 2});
auto exp0 = NDArrayFactory::create<float>('c', {10, 3, 2});
al.linspace(1.0);
be.assign(2.0);
nd4j::ops::random_gamma op;
auto result = op.execute({&x, &al, &be}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Gamma distribution");
ASSERT_TRUE(exp0.isSameShape(z));
ASSERT_FALSE(exp0.equalsTo(z));
delete result;
}
namespace nd4j {
namespace tests {
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {

View File

@ -109,22 +109,22 @@ endif()
# -fsanitize=address
# -fsanitize=leak
if (APPLE)
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -D__APPLE_OS__=true")
set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -D__APPLE_OS__=true")
elseif(WIN32)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations -Wa,-mbig-obj")
set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -Wa,-mbig-obj")
endif()
else()
if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release")
message("Release build for tests")
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations")
set(CMAKE_CXX_FLAGS "-O3 -fPIC -std=c++11")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
endif()
else()
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 -fassociative-math -funsafe-math-optimizations")
set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 ")
if (NOT CUDA_BLAS)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
endif()

View File

@ -975,8 +975,8 @@ public class DifferentialFunctionFactory {
return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
}
public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
return new BiasAddGrad(sameDiff(), input, bias, grad).outputVariables();
public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) {
return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables();
}
public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) {

View File

@ -109,6 +109,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class,

View File

@ -45,12 +45,14 @@ public class BiasAdd extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[] {input, bias}, false);
bArguments.clear();
bArguments.add(nchw);
this.nchw = nchw;
}
public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
super(new INDArray[]{input, bias}, wrapOrNull(output));
bArguments.clear();
bArguments.add(nchw);
this.nchw = nchw;
}
@Override
@ -80,7 +82,7 @@ public class BiasAdd extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> gradient){
return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0)));
return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw));
}
@Override

View File

@ -31,9 +31,12 @@ import java.util.Collections;
import java.util.List;
public class BiasAddGrad extends DynamicCustomOp {
protected boolean nchw = true;
public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient) {
public BiasAddGrad(SameDiff sameDiff, SDVariable input, SDVariable bias, SDVariable gradient, boolean nchw) {
super(null, sameDiff, new SDVariable[]{input, bias, gradient});
this.nchw = nchw;
addBArgument(nchw);
}
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient, INDArray output){
@ -52,8 +55,6 @@ public class BiasAddGrad extends DynamicCustomOp {
return "biasadd_bp";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Differentiation not supported for op " + getClass().getSimpleName());

View File

@ -147,69 +147,12 @@ public class DeConv3D extends DynamicCustomOp {
return config.getValue(property);
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
int sD, sH, sW, dD=1, dH=1, dW=1;
val aStrides = nodeDef.getAttrOrThrow("strides");
List<Long> tfStrides = aStrides.getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
List<Long> tfDilation = null;
if (attributesForNode.containsKey("dilations")) {
tfDilation = attributesForNode.get("dilations").getList().getIList(); //[mb,c,d,h,w] or [mb,d,h,w,c] depending on format. mb/c are always 1
}
val aPadding = nodeDef.getAttrOrDefault("padding", null);
String paddingMode = aPadding.getS().toStringUtf8();
String dataFormat = "NDHWC";
if (nodeDef.containsAttr("data_format")) {
val attr = nodeDef.getAttrOrThrow("data_format");
dataFormat = attr.getS().toStringUtf8().toLowerCase();
}
if(dataFormat.equalsIgnoreCase("NCDHW")){
sD = tfStrides.get(2).intValue();
sH = tfStrides.get(3).intValue();
sW = tfStrides.get(4).intValue();
if(tfDilation != null){
dD = tfDilation.get(2).intValue();
dH = tfDilation.get(3).intValue();
dW = tfDilation.get(4).intValue();
}
} else {
sD = tfStrides.get(1).intValue();
sH = tfStrides.get(2).intValue();
sW = tfStrides.get(3).intValue();
if(tfDilation != null){
dD = tfDilation.get(1).intValue();
dH = tfDilation.get(2).intValue();
dW = tfDilation.get(3).intValue();
}
}
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
this.config = DeConv3DConfig.builder()
.kD(-1).kH(-1).kW(-1) //Infer from kernel
.sD(sD).sH(sW).sW(sH)
.dD(dD).dH(dH).dW(dW)
.isSameMode(isSameMode)
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
.build();
addArgs();
}
@Override
public String opName() {
return "deconv3d";
}
@Override
public String tensorflowName() {
return "Conv3DBackpropInputV2";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
@ -224,4 +167,4 @@ public class DeConv3D extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}
}

View File

@ -0,0 +1,208 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* DeConv3D operation, TF-wrapper
*/
@Slf4j
@Getter
@NoArgsConstructor
public class DeConv3DTF extends DynamicCustomOp {
protected DeConv3DConfig config;
public DeConv3DTF(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable weights, @NonNull SDVariable input, @NonNull DeConv3DConfig config) {
super(sameDiff, new SDVariable[]{shape, weights, input});
this.config = config;
addArgs();
}
@Override
public long[] iArgs() {
if (iArguments.size() == 0)
addArgs();
return super.iArgs();
}
@Override
public Map<String, Object> propertiesForFunction() {
if(config == null && !iArguments.isEmpty()){
config = DeConv3DConfig.builder()
.kD(iArguments.get(0))
.kH(iArguments.get(1))
.kW(iArguments.get(2))
.sD(iArguments.get(3))
.sH(iArguments.get(4))
.sW(iArguments.get(5))
.pD(iArguments.get(6))
.pH(iArguments.get(7))
.pW(iArguments.get(8))
.dD(iArguments.get(9))
.dH(iArguments.get(10))
.dW(iArguments.get(11))
.isSameMode(iArguments.get(12) == 1)
.dataFormat(iArguments.get(13) == 1 ? DeConv3DConfig.NDHWC : DeConv3DConfig.NCDHW)
.build();
}
return config.toProperties();
}
private void addArgs() {
addIArgument(config.getKD());
addIArgument(config.getKH());
addIArgument(config.getKW());
addIArgument(config.getSD());
addIArgument(config.getSH());
addIArgument(config.getSW());
addIArgument(config.getPD());
addIArgument(config.getPH());
addIArgument(config.getPW());
addIArgument(config.getDD());
addIArgument(config.getDH());
addIArgument(config.getDW());
addIArgument(ArrayUtil.fromBoolean(config.isSameMode()));
addIArgument(config.getDataFormat().equalsIgnoreCase(DeConv3DConfig.NCDHW) ? 0 : 1);
}
@Override
public boolean isConfigProperties() {
return true;
}
@Override
public String configFieldName() {
return "config";
}
@Override
public Object getValue(Field property) {
if (config == null) {
config = DeConv3DConfig.builder().build();
}
return config.getValue(property);
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
val aStrides = nodeDef.getAttrOrThrow("strides");
val aDilations = nodeDef.getAttrOrDefault("dilations", null);
val tfStrides = aStrides.getList().getIList();
val tfDilation = aDilations == null ? null : aDilations.getList().getIList();
int sD, sH, sW, dD, dH, dW;
val aPadding = nodeDef.getAttrOrDefault("padding", null);
String paddingMode = aPadding.getS().toStringUtf8();
String dataFormat = DeConv3DConfig.NDHWC;
if (nodeDef.containsAttr("data_format")) {
val attr = nodeDef.getAttrOrThrow("data_format");
dataFormat = attr.getS().toStringUtf8().toLowerCase();
}
if (dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW)) {
sD = tfStrides.get(2).intValue();
sH = tfStrides.get(3).intValue();
sW = tfStrides.get(4).intValue();
dD = tfDilation == null ? 1 : tfDilation.get(2).intValue();
dH = tfDilation == null ? 1 : tfDilation.get(3).intValue();
dW = tfDilation == null ? 1 : tfDilation.get(4).intValue();
} else {
sD = tfStrides.get(1).intValue();
sH = tfStrides.get(2).intValue();
sW = tfStrides.get(3).intValue();
dD = tfDilation == null ? 1 : tfDilation.get(1).intValue();
dH = tfDilation == null ? 1 : tfDilation.get(2).intValue();
dW = tfDilation == null ? 1 : tfDilation.get(3).intValue();
}
boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
DeConv3DConfig conv3DConfig = DeConv3DConfig.builder()
.kD(-1)
.kH(-1)
.kW(-1)
.sD(sD)
.sH(sW)
.sW(sH)
.dD(dD)
.dH(dH)
.dW(dW)
.isSameMode(isSameMode)
.dataFormat(dataFormat.equalsIgnoreCase(DeConv3DConfig.NCDHW) ? DeConv3DConfig.NCDHW : DeConv3DConfig.NDHWC)
.build();
this.config = conv3DConfig;
addArgs();
}
@Override
public String opName() {
return "deconv3d_tf";
}
@Override
public String[] tensorflowNames() {
return new String[]{"Conv3DBackpropInput", "Conv3DBackpropInputV2"};
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Backprop not yet implemented for " + getClass());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //inShape, weights, input
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(2));
}
}

View File

@ -39,6 +39,7 @@ import java.util.*;
@Slf4j
public class Concat extends DynamicCustomOp {
private int concatDimension = -1;
private boolean isDynamicAxis = false;
public Concat(){
@ -83,73 +84,11 @@ public class Concat extends DynamicCustomOp {
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
Map<String,PropertyMapping> concatMap = new HashMap<>();
val concatDimProps = PropertyMapping.builder()
.tfInputPosition(0)
.onnxAttrName("axis")
.build();
concatMap.put("concatDimension",concatDimProps);
Map<String,PropertyMapping> concatV2Map = new HashMap<>();
val concat2DimProps = PropertyMapping.builder()
//lalst position
.tfInputPosition(-1)
.onnxAttrName("axis")
.build();
concatV2Map.put("concatDimension",concat2DimProps);
//note that onnx is already covered here
ret.put(tensorflowNames()[0],concatMap);
ret.put(tensorflowNames()[1],concatV2Map);
return ret;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
int concatDimension = -1;
String input = null;
val inputCount = nodeDef.getInputCount();
for(int i = 0; i < inputCount; i++) {
if(nodeDef.getInput(i).contains("/concat_dim")) {
input = nodeDef.getInput(i);
break;
}
}
//older versions may specify a concat_dim, usually it's the last argument
if(input == null) {
input = nodeDef.getInput(nodeDef.getInputCount() - 1);
}
val variable = initWith.getVariable(input);
// concat dimension is only possible
if (variable != null) {
val arr = variable.getArr();
if (arr.length() == 1) {
concatDimension = arr.getInt(0);
}
this.concatDimension = concatDimension;
addIArgument(this.concatDimension);
log.trace("Concat dimension: {}", concatDimension);
}
//don't pass both iArg and last axis down to libnd4j
if(inputArguments().length == nodeDef.getInputCount()) {
val inputArgs = inputArguments();
removeInputArgument(inputArgs[inputArguments().length - 1]);
}
//TODO Fix this: https://github.com/eclipse/deeplearning4j/issues/8285
sameDiff.removeArgFromOp(input,this);
//TF uses dynamic axis - last argument is a scalar integer array for axis
addBArgument(true);
isDynamicAxis = true;
}
@Override
@ -159,12 +98,6 @@ public class Concat extends DynamicCustomOp {
return ret;
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}
@Override
public String onnxName() {
return "Concat";
@ -175,7 +108,6 @@ public class Concat extends DynamicCustomOp {
return "Concat";
}
@Override
public String[] tensorflowNames() {
return new String[] {"Concat","ConcatV2"};
@ -189,18 +121,32 @@ public class Concat extends DynamicCustomOp {
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable[] args = args();
SDVariable[] bpArgs = Arrays.copyOf(args, args.length + 1);
bpArgs[bpArgs.length-1] = i_v.get(0);
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
SDVariable[] bpArgs;
if(isDynamicAxis){
bpArgs = Arrays.copyOf(args, args.length + 2);
bpArgs[bpArgs.length - 1] = bpArgs[bpArgs.length - 3]; //Last input is axis -> move to end of bp args too
bpArgs[bpArgs.length - 2] = i_v.get(0);
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
} else {
bpArgs = Arrays.copyOf(args, args.length + 1);
bpArgs[bpArgs.length - 1] = i_v.get(0);
return Arrays.asList(new ConcatBp(sameDiff, concatDimension, bpArgs).outputVariables());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
DataType first = dataTypes.get(0);
for( int i=1; i<dataTypes.size(); i++ ){
for( int i=1; i<dataTypes.size() - (isDynamicAxis ? 1 : 0); i++ ){
DataType dt = dataTypes.get(i);
Preconditions.checkState(first == dt, "All inputs must have same datatype - got %s and %s for inputs 0 and %s respectively", first, dt, i);
}
if(isDynamicAxis) {
Preconditions.checkState(dataTypes.get(dataTypes.size() - 1).isIntType(),
"For dynamic axis case, last datatype must be an integer type, got input types %s");
}
//Output type is same as input types
return Collections.singletonList(first);
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
@ -42,6 +43,7 @@ import java.util.*;
@Slf4j
public class ConcatBp extends DynamicCustomOp {
private int concatDimension;
private boolean dynamicAxis;
public ConcatBp(){
@ -53,38 +55,30 @@ public class ConcatBp extends DynamicCustomOp {
* @param concatDimension
* @param inputsAndGrad Original inputs, followed by output gradient
*/
public ConcatBp(SameDiff sameDiff, int concatDimension, SDVariable... inputsAndGrad){
public ConcatBp(@NonNull SameDiff sameDiff, int concatDimension, @NonNull SDVariable... inputsAndGrad){
super(null, sameDiff, inputsAndGrad);
addIArgument(concatDimension);
this.concatDimension = concatDimension;
}
/**
*
* @param sameDiff SameDiff instance
* @param inputsGradAxis Inputs, gradient array, and axis
*/
public ConcatBp(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputsGradAxis){
super(null, sameDiff, inputsGradAxis);
Preconditions.checkState(inputsGradAxis[inputsGradAxis.length-1].dataType().isIntType(),
"When using this constructor, the last input must be an integer array (for the axis)");
addBArgument(true); //Last argument
this.dynamicAxis = true;
}
@Override
public String opName() {
return "concat_bp";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
//No op
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
//No op
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public Op.Type opType() {
return Op.Type.CUSTOM;
@ -92,7 +86,7 @@ public class ConcatBp extends DynamicCustomOp {
@Override
public int getNumOutputs(){
return args().length - 1;
return args().length - 1 - (dynamicAxis ? 1 : 0);
}
@Override

View File

@ -1358,4 +1358,35 @@ public class LayerOpValidation extends BaseOpValidation {
.build());
assertEquals(outCC, outFC); //Fails here
}
@Test
public void testBiasAdd_nchw_nhwc() {
Nd4j.getRandom().setSeed(12345);
for(boolean nchw : new boolean[]{true, false}) {
log.info("Starting test: {}", nchw ? "nchw" : "nhwc");
SameDiff sameDiff = SameDiff.create();
SDVariable in = sameDiff.var("input", Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{2,4,3,3} : new long[]{2,3,3,4}));
SDVariable b = sameDiff.var("bias", Nd4j.rand(DataType.DOUBLE, new long[]{4}));
SDVariable bAdd = sameDiff.nn.biasAdd(in, b, nchw);
SDVariable loss = bAdd.std(true);
INDArray exp = in.getArr().dup();
if(nchw){
exp.addi(b.getArr().reshape(1,4,1,1));
} else {
exp.addi(b.getArr().reshape(1,1,1,4));
}
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(bAdd.name(), exp);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
}

View File

@ -99,7 +99,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinomial/.*",
//2019/11/02 AB - need deconv3d changes (for handling shape)
//2019/11/04 AB - disabled, pending libnd4j deconv3d_tf implementation
"conv3d_transpose.*"
};