Shyrma lstm layer bp (#370)

* - start working on bp for lstm

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further working on bp for lstmLayer

Signed-off-by: Yurii <iuriish@yahoo.com>

* - minor change

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 3

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 4

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 5

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 6

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 7

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 8

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 9

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing lstmLayerCellBp helper

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implement lstmLayerCellBp as separate op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implement lstmLayerBp as separate op (not tested)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fixing calculations of dLdWp and dLdb in lstmLayerCellBp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 10

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fixing typo in lstmLayerTimeLoop

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on bp for lstmLayer 10

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in lstmLayer_bp op 1

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in lstmLayer_bp op 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - turn off heavy tests for cuda for lstmLayer_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to nullify gradients at eliminated time steps (when sequnce length array is present )

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2020-04-13 13:21:51 +03:00 committed by GitHub
parent f1debe8c07
commit 23e4aa99ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 2896 additions and 167 deletions

View File

@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::Launch
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
// u8 string constructors // u8 string constructors
/////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {
if (!DataTypeUtils::isS(dtype)) { if (!DataTypeUtils::isS(dtype)) {

View File

@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !"); throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach"); throw std::runtime_error("Shapes mismach");
} }

View File

@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck {
* opBP - back propagation operation * opBP - back propagation operation
* argsHolderFF - argument holder for feed forward operation * argsHolderFF - argument holder for feed forward operation
* argsHolderBP - argument holder for back propagation operation * argsHolderBP - argument holder for back propagation operation
* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.} * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
* outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
*/ */
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM); const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
}; };

View File

@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0); int xLenDim(0), yLenDim(0);
if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); throw std::runtime_error("MmulHelper::dot: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
if(Z != nullptr && !Z->isScalar()) if(Z != nullptr && !Z->isScalar())
throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");
const auto length = X->lengthOf(); const auto length = X->lengthOf();
if(Y->lengthOf() != length) if(Y->lengthOf() != length)
throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !");
if(Z == nullptr) if(Z == nullptr)
Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());

View File

@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) { const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
@ -82,6 +82,16 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
int numOutArrs = outArrsFF.size(); int numOutArrs = outArrsFF.size();
double scorePlus = 0.; double scorePlus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
@ -89,12 +99,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0); scorePlus += tmpScalar.e<double>(0);
} }
}
// subtract epsilon, feed forward // subtract epsilon, feed forward
inArrsFF[i]->p<double>(j, orig - EPSILON); inArrsFF[i]->p<double>(j, orig - EPSILON);
outArrsFF = opFF.execute(argsHolderFF); outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.; double scoreMinus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
@ -102,6 +123,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0); scoreMinus += tmpScalar.e<double>(0);
} }
}
// restore initial element value // restore initial element value
inArrsFF[i]->p<double>(j, orig); inArrsFF[i]->p<double>(j, orig);
@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
throw std::runtime_error(""); throw std::runtime_error("");
} }
// printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad); // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad);
// calculate relative error // calculate relative error
double relError; double relError;
@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR) if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR)
continue; continue;
printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad); printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad);
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
return false; return false;
} }

View File

@ -253,7 +253,8 @@
(45, ReversePow), \ (45, ReversePow), \
(46, DivideNoNan), \ (46, DivideNoNan), \
(47, IGamma), \ (47, IGamma), \
(48, IGammac) (48, IGammac), \
(49, RELUDerivative)

View File

@ -24,10 +24,10 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h> #include<ops/declarable/helpers/lstmLayer.h>
namespace sd { namespace sd {
namespace ops { namespace ops {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc) // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't // ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct) // ht = ot ◦ tanh(ct)
@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// 2) [2, nOut, 4*nOut] when directionMode >= 2 // 2) [2, nOut, 4*nOut] when directionMode >= 2
// ******* // *******
// peephole weights Wp: // peephole weights Wp, optional:
// 1) [3*nOut] when directionMode < 2 // 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2 // 2) [2, 3*nOut] when directionMode >= 2
// ******* // *******
// biases b: // biases b, optional:
// 1) [4*nOut] when directionMode < 2 // 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2 // 2) [2, 4*nOut] when directionMode >= 2
// ******* // *******
// sequence length array seqLen: // sequence length array seqLen, optional:
// 1) [bS] always // 1) [bS]
// ******* // *******
// initial output hI: // initial output hI, optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// ******* // *******
// initial cell state cI (same shape as in hI): // initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// OUTPUTS: // OUTPUTS:
// ******* // *******
// output h: // output h, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
@ -109,19 +109,19 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// ******* // *******
// output at last step hL: // output at last step hL, optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// ******* // *******
// cell state at last step cL (same shape as in hL): // cell state at last step cL (same shape as in hL), optional:
// 1) [bS, nOut] when directionMode < 2 // 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2 // 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot // !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot // !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// evaluate dimensions // evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations // inputs validations
@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) {
const auto Wr = INPUT_VARIABLE(2); // recurrent weights const auto Wr = INPUT_VARIABLE(2); // recurrent weights
// evaluate dimensions // evaluate dimensions
const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) ); const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
DataType type; DataType type;
@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) {
} }
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// sL - sequence length, number of time steps
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// *******
// input x:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// input weights Wx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// recurrent weights Wr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// peephole weights Wp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// biases b, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// *******
// sequence length array seqLen, optional:
// 1) [bS]
// *******
// initial output hI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. output dLdh, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
// 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
// 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
// 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// *******
// gradient vs output at last time step dLdhL, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// OUTPUTS:
// *******
// gradient vs. input dLdx:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// gradient vs. input weights dLdWx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// gradient vs. recurrent weights dLdWr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// gradient vs. peephole weights dLdWp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// gradient vs. biases dLdb, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// gradient vs. sequence length array dLdsL, optional (do not calculate it!!!):
// 1) [bS] always
// *******
// gradient vs. initial output dLdhI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG(3); // activation for cell state (c)
const auto outAct = INT_ARG(4); // activation for output (h)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1}
const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given
const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
const auto outActHasBeta = outAct == 3 || outAct == 6;
uint count = 1;
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !");
REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !");
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
count = 3;
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output
const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step
const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step
count = 3;
auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input
auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights
auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights
auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases
auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!!
auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output
auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state
auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights
// evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
if(directionMode < 2) { // no bidirectional
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
// gradient vs. output at last time step validation
if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
// gradient vs. cell state at last time step validation
if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
// gradient vs. output at last time step validation
if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
// gradient vs. cell state at last time step validation
if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
}
// gradient vs. output validation
if(dLdh) {
int factor = directionMode <= 2 ? 1 : 2;
std::vector<Nd4jLong> expdLdhShape;
if(dataFormat == 0) expdLdhShape = std::vector<Nd4jLong>{sL, bS, factor*nOut};
else if(dataFormat == 1) expdLdhShape = std::vector<Nd4jLong>{bS, sL, factor*nOut};
else if(dataFormat == 2) expdLdhShape = std::vector<Nd4jLong>{bS, factor*nOut, sL};
else expdLdhShape = std::vector<Nd4jLong>{sL, 2, bS, nOut};
REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
if(directionMode == 0) { // forward
helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
}
else if(directionMode == 1) { // backward
helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
}
else { // bidirectional
NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0});
NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0});
NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});
NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0});
NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0});
NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
*dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr),
*dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr),
*dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr);
if(Wp) {
WpFwd = new NDArray((*Wp)({0,1, 0,0}));
WpBwd = new NDArray((*Wp)({1,2, 0,0}));
dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0}));
dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0}));
}
if(b) {
bFwd = new NDArray((*b)({0,1, 0,0}));
bBwd = new NDArray((*b)({1,2, 0,0}));
dLdbFwd = new NDArray((*dLdb)({0,1, 0,0}));
dLdbBwd = new NDArray((*dLdb)({1,2, 0,0}));
}
if(hI) {
hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0}));
hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0}));
dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0}));
dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0}));
}
if(cI) {
cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0}));
cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0}));
dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0}));
dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0}));
}
if(dLdhL) {
dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0}));
dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0}));
}
if(dLdcL) {
dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0}));
dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
}
// FIXME looks like sum (directionMode == 2) is impossible for backprop
if(dLdh) {
if(directionMode == 2) { // sum
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
// dLdhFwd = dLdh;
// dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content
}
else if(directionMode == 3) { // concat
dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0}));
}
else { // directionMode == 4
dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0}));
dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0}));
}
}
helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
NDArray dLdxBwd = dLdx->ulike();
helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);
*dLdx += dLdxBwd;
delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd;
delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;
if(dLdhFwd != dLdh)
delete dLdhFwd;
}
return Status::OK();
}
DECLARE_TYPES(lstmLayer_bp) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayer_bp) {
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
int count = 3;
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
std::vector<Nd4jLong*> outShapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};
if(b != nullptr)
outShapes.push_back(b->getShapeInfo());
if(seqLen != nullptr)
outShapes.push_back(seqLen->getShapeInfo());
if(hI != nullptr)
outShapes.push_back(hI->getShapeInfo());
if(cI != nullptr)
outShapes.push_back(cI->getShapeInfo());
if(Wp != nullptr)
outShapes.push_back(Wp->getShapeInfo());
return new ShapeList(outShapes);
}
} }
} }

View File

@ -0,0 +1,339 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayerCell)
#include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h>
namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// input x: [bS, nIn] or [nIn]
// input weights Wx: [nIn, 4*nOut]
// recurrent weights Wr: [nOut, 4*nOut]
// initial (previous) output hI: [bS, nOut] or [nOut]
// initial (previous) cell state cI: [bS, nOut] or [nOut]
// biases b (optional): [4*nOut]
// peephole weights Wp (optional): [3*nOut]
// OUTPUTS:
// current output h: [bS, nOut] or [nOut]
// current cell state c: [bS, nOut] or [nOut]
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG(1); // activation for cell state (c)
const auto outAct = INT_ARG(2); // activation for output (h)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasPH = B_ARG(1); // indicates whether peephole connections are present
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
const auto outActHasBeta = outAct == 3 || outAct == 6;
uint count = 1;
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
count = 3;
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto hI = INPUT_VARIABLE(count++); // initial output
const auto cI = INPUT_VARIABLE(count++); // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");
auto h = OUTPUT_VARIABLE(0);
auto c = OUTPUT_VARIABLE(1);
// evaluate dimensions
const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0);
const Nd4jLong nIn = x->sizeAt(-1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// initial output/cell validation
std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);
return Status::OK();
}
DECLARE_TYPES(lstmLayerCell) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayerCell) {
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
uint count = hasBiases ? 4 : 3;
const auto hI = INPUT_VARIABLE(count++); // initial output
const auto cI = INPUT_VARIABLE(count); // initial cell state
return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// input x: [bS, nIn] or [nIn]
// input weights Wx: [nIn, 4*nOut]
// recurrent weights Wr: [nOut, 4*nOut]
// initial (previous) output hI: [bS, nOut] or [nOut]
// initial (previous) cell state cI: [bS, nOut] or [nOut]
// gradient wrt output dLdh: [bS, nOut] or [nOut]
// gradient wrt cell state dLdc: [bS, nOut] or [nOut]
// peephole weights Wp (optional): [3*nOut]
// biases b (optional): [4*nOut]
// OUTPUTS:
// gradient wrt x dLdx: [bS, nIn] or [nIn]
// gradient wrt Wx dLdWx: [nIn, 4*nOut]
// gradient wrt Wr dLdWr: [nOut, 4*nOut]
// gradient wrt hI dLdhI: [bS, nOut] or [nOut]
// gradient wrt cI dLdcI: [bS, nOut] or [nOut]
// gradient wrt b dLdb (optional): [4*nOut]
// gradient wrt Wp dLdWp (optional): [3*nOut]
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG(1); // activation for cell state (c)
const auto outAct = INT_ARG(2); // activation for output (h)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasPH = B_ARG(1); // indicates whether peephole connections are present
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
const auto outActHasBeta = outAct == 3 || outAct == 6;
uint count = 1;
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
count = 3;
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto hI = INPUT_VARIABLE(count++); // initial output
const auto cI = INPUT_VARIABLE(count++); // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");
count = 3;
auto dLdx = OUTPUT_VARIABLE(0);
auto dLdWx = OUTPUT_VARIABLE(1);
auto dLdWr = OUTPUT_VARIABLE(2);
auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
auto dLdhI = OUTPUT_VARIABLE(count++);
auto dLdcI = OUTPUT_VARIABLE(count++);
auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;
// evaluate dimensions
const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0);
const Nd4jLong nIn = x->sizeAt(-1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
// initial output/cell validation
std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());
std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});
NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
NDArray a = z.ulike();
NDArray h = cI->ulike();
NDArray c = cI->ulike();
helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);
helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
return Status::OK();
}
DECLARE_TYPES(lstmLayerCellBp) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayerCellBp) {
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasPH = B_ARG(1); // indicates whether peephole connections are present
uint count = 3;
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto hI = INPUT_VARIABLE(count++); // initial output
const auto cI = INPUT_VARIABLE(count++); // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights
std::vector<Nd4jLong*> shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};
if(b != nullptr)
shapes.push_back(b->getShapeInfo());
shapes.push_back(hI->getShapeInfo());
shapes.push_back(cI->getShapeInfo());
if(Wp != nullptr)
shapes.push_back(Wp->getShapeInfo());
return new ShapeList(shapes);
}
}
}
#endif

View File

@ -149,6 +149,13 @@ namespace ops {
DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2);
#endif #endif
#if NOT_EXCLUDED(OP_lstmLayerCell)
DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3);
#endif
#if NOT_EXCLUDED(OP_lstmLayerCell)
DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3);
#endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
@ -236,6 +243,11 @@ namespace ops {
DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
#endif #endif
//////////////////////////////////////////////////////////////////////////
#if NOT_EXCLUDED(OP_lstmLayer)
DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5);
#endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**

File diff suppressed because it is too large Load Diff

View File

@ -22,7 +22,6 @@
#define LIBND4J_LSTMLAYER_H #define LIBND4J_LSTMLAYER_H
#include <ops/declarable/helpers/helpers.h> #include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/activations.h>
namespace sd { namespace sd {
namespace ops { namespace ops {
@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
const std::vector<float>& params, const std::vector<float>& params,
NDArray* h, NDArray* c); NDArray* h, NDArray* c);
//////////////////////////////////////////////////////////////////////////
// this auxiliary ff should be running before backprop
void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* z, NDArray* a, NDArray* h, NDArray* c);
//////////////////////////////////////////////////////////////////////////
void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdc,
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND
NDArray* h, NDArray* hL, NDArray* cL); NDArray* h, NDArray* hL, NDArray* cL);
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
switch (opId) { const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
case 0: const std::vector<float>& params, const bool forward,
(const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z); NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp);
break;
case 1:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
break;
case 2:
(const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
break;
case 3: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
break;
}
case 4:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
break;
case 5:
helpers::thresholdRelu(x.getContext(), x, alpha, z);
break;
case 6: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
break;
}
case 7:
(const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
break;
case 8:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
break;
case 9:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
break;
case 10:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
break;
default:
throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
}
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
if(dataFormat == 0 || dataFormat == 3)
return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn]
if(dataFormat == 1)
return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn]
return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL]
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
if(dataFormat == 0 || dataFormat == 3)
return t * bS + b; // TNS: shape [sL, bS, nIn]
return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
} }

View File

@ -1441,7 +1441,7 @@ namespace simdOps {
} }
op_def static Z op(X d1) { op_def static Z op(X d1) {
return d1; return static_cast<Z>(d1);
} }
}; };
@ -2434,6 +2434,19 @@ namespace simdOps {
} }
}; };
template <typename X, typename Y, typename Z>
class RELUDerivative {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda
op_def static Z op(X d1, Y d2, Z *params) {
auto xt = static_cast<Z>(d1);
auto xf = static_cast<Z>(d2);
return xt > xf ? static_cast<Z>(1.f) : static_cast<Z>(0.f);
}
};
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
class SXELogitsSmoother { class SXELogitsSmoother {
public: public:

View File

@ -1109,6 +1109,7 @@ TEST_F(DeclarableOpsTests13, mergeavg_bp_1) {
} }
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_1) { TEST_F(DeclarableOpsTests13, lstmLayer_1) {
@ -1200,7 +1201,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
const auto hasInitC = true; // initial cell state is provided const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step const auto retLastH = true; // return output at last time step
const auto retLastC = true; // return cells state at last time step const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping const double cellClip = 0; // do not apply clipping
@ -2061,10 +2062,528 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
ASSERT_TRUE(expCL.isSameShape(cL)); ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL)); ASSERT_TRUE(expCL.equalsTo(cL));
#endif #endif
} }
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::SUM, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = false; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = false; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
const int sL = 4;
const int bS = 3;
const int nIn = 3;
const int nOut = 2;
const int dataFormat = 2; // [bS, nIn, sL]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 1; // backward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 2;
const int dataFormat = 2; // [bS, nIn, sL]
const int directionMode = 1; // backward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 2;
const int dataFormat = 2; // [bS, nIn, sL]
const int directionMode = 2; // bidirectional sum
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 2;
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 2;
const int dataFormat = 3; // [sL, bS, nIn]
const int directionMode = 4; // bidirectional extra output dim
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
ASSERT_TRUE(isGradCorrect);
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_test1) { TEST_F(DeclarableOpsTests13, batchnorm_test1) {

View File

@ -844,5 +844,78 @@ TEST_F(PlaygroundTests, my) {
printf("time: %i \n", time); printf("time: %i \n", time);
} }
///////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, lstmLayerCellBp_1) {
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// const int nIn = 8;
// const int nOut = 6;
const float cellClip = 1.1; // clipping value
const Nd4jLong gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid
const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid
const Nd4jLong cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const Nd4jLong outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh
NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdc('c', {bS, nOut}, sd::DataType::DOUBLE);
// NDArray x ('c', {nIn}, sd::DataType::DOUBLE);
// NDArray hI('c', {nOut}, sd::DataType::DOUBLE);
// NDArray cI('c', {nOut}, sd::DataType::DOUBLE);
// NDArray dLdh('c', {nOut}, sd::DataType::DOUBLE);
// NDArray dLdc('c', {nOut}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b ('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
x.linspace(-4,1);
hI.linspace(-2.5,0.5);
cI.linspace(-3,0.5);
Wx.linspace(0,0.1);
Wr.linspace(3,-0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
// x.assign(1.);
// hI.assign(2.);
// cI.assign(3.);
// Wx.assign(0.5);
// Wr.assign(0.5);
// Wp.assign(0.75);
// b.assign(0.7);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {gateAct, cellAct, outAct};
// std::vector<bool> bArgs = {false, false};
// const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs);
std::vector<bool> bArgs = {true, true};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
sd::ops::lstmLayerCell opFF;
sd::ops::lstmLayerCellBp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true});
}
*/ */