Shyrma lstm layer bp (#370)
* - start working on bp for lstm
* - further working on bp for lstmLayer
* - minor change
* - further work on bp for lstmLayer 2
* - further work on bp for lstmLayer 3
* - further work on bp for lstmLayer 4
* - further work on bp for lstmLayer 5
* - further work on bp for lstmLayer 6
* - further work on bp for lstmLayer 7
* - further work on bp for lstmLayer 8
* - further work on bp for lstmLayer 9
* - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs
* - testing and fixing lstmLayerCellBp helper
* - implement lstmLayerCellBp as separate op
* - implement lstmLayerBp as separate op (not tested)
* - fixing calculations of dLdWp and dLdb in lstmLayerCellBp
* - further work on bp for lstmLayer 10
* - fixing typo in lstmLayerTimeLoop
* - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp
* - further work on bp for lstmLayer 10
* - testing and fixing bugs in lstmLayer_bp op 1
* - testing and fixing bugs in lstmLayer_bp op 2
* - turn off heavy tests for cuda for lstmLayer_bp op
* - forgot to nullify gradients at eliminated time steps (when sequence length array is present)

Signed-off-by: Yurii <iuriish@yahoo.com>
parent f1debe8c07
commit 23e4aa99ad
@@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext
/////////////////////////////////////////////////////////////////////////
// u8 string constructors
/////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {

    if (!DataTypeUtils::isS(dtype)) {
@@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");

if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
-    nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
+    nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n","");
    throw std::runtime_error("Shapes mismach");
}
@@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck {
* opBP - back propagation operation
* argsHolderFF - argument holder for feed forward operation
* argsHolderBP - argument holder for back propagation operation
-* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays
+* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
+* outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
*/
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
-                      const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM);
+                      const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
};
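The added outArrsFFIdx argument matters for ops like lstmLayer_bp below: the layer can return h, hL and cL, where hL and cL can duplicate information already present in h, so accumulating the scalar loss over all outputs would double-count. A hypothetical invocation (the op and array names here are placeholders, not taken from this diff):

    sd::ops::lstmLayer    opFF;
    sd::ops::lstmLayer_bp opBP;
    OpArgsHolder argsHolderFF({&x, &Wx, &Wr},        {0.}, {0,0,2,0,0});
    OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &dLdh}, {0.}, {0,0,2,0,0});
    // empty whatArrsToCheck -> check all input gradients; accumulate the loss
    // only over feed-forward outputs 0 and 2, which are mutually independent
    bool ok = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP,
                                   {}, {0., 1.}, GradCheck::LossFunc::SUM, {0, 2});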
@@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
int xLenDim(0), yLenDim(0);

if(!shape::isCommonVector(X->getShapeInfo(), xLenDim))
-    throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !");
+    throw std::runtime_error("MmulHelper::dot: X array must be vector !");
if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim))
-    throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !");
+    throw std::runtime_error("MmulHelper::dot: Y array must be vector !");
if(Z != nullptr && !Z->isScalar())
-    throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !");
+    throw std::runtime_error("MmulHelper::dot: Z array must be scalar !");

const auto length = X->lengthOf();

if(Y->lengthOf() != length)
-    throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !");
+    throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !");

if(Z == nullptr)
    Z = new NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext());
@@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&

//////////////////////////////////////////////////////////////////////////
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
-                          const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
+                          const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {

const int numInArrsFF     = argsHolderFF.getNumInArrs();                // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF;  // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
@@ -82,12 +82,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
int numOutArrs = outArrsFF.size();
double scorePlus = 0.;

-for(int k = 0; k < numOutArrs; ++k) {          // loop through output arrays
-    if(loss == SUM)
-        outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-    else
-        outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-    scorePlus += tmpScalar.e<double>(0);
+if(!outArrsFFIdx.empty()) {
+    for(const auto& k : outArrsFFIdx) {        // loop through independent output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scorePlus += tmpScalar.e<double>(0);
+    }
+}
+else {
+    for(int k = 0; k < numOutArrs; ++k) {      // loop through output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scorePlus += tmpScalar.e<double>(0);
+    }
}

// subtract epsilon, feed forward
@@ -95,12 +106,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.;

-for(int k = 0; k < numOutArrs; ++k) {          // loop through output arrays
-    if(loss == SUM)
-        outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-    else
-        outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-    scoreMinus += tmpScalar.e<double>(0);
+if(!outArrsFFIdx.empty()) {
+    for(const auto& k : outArrsFFIdx) {        // loop through independent output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scoreMinus += tmpScalar.e<double>(0);
+    }
+}
+else {
+    for(int k = 0; k < numOutArrs; ++k) {      // loop through output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scoreMinus += tmpScalar.e<double>(0);
+    }
}

// restore initial element value
@@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
    throw std::runtime_error("");
}

-// printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad);
+// printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad);

// calculate relative error
double relError;
@@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons

if(math::nd4j_abs<double>(analyticGrad - numericalGrad) < MINABSERR)
    continue;
-printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad);
+printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad);
printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j);
return false;
}
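For context, checkGrad perturbs one input element by ±EPSILON, recomputes the scalar loss (the scorePlus/scoreMinus sums above), and compares the central-difference estimate against the analytic gradient produced by the bp op:

    numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON)
    relError      = |analyticGrad - numericalGrad| / (|analyticGrad| + |numericalGrad|)

The relError denominator shown is the usual symmetric normalization; its exact definition sits between these hunks and is an assumption here, but the MINABSERR/MAXRELERR thresholds above are applied to exactly these two quantities.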
@@ -253,7 +253,8 @@
(45, ReversePow), \
(46, DivideNoNan), \
(47, IGamma), \
-(48, IGammac)
+(48, IGammac), \
+(49, RELUDerivative)
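RELUDerivative is a new entry in this pairwise-transform list, presumably registered to back the relu activation option (id 1) of the LSTM ops below. As a pairwise op it would take the pre-activation x and a cutoff t and implement the standard derivative (an assumption; the kernel itself is not part of this diff):

    f'(x) = 1 if x > t, else 0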
@@ -24,10 +24,10 @@
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>


namespace sd {
namespace ops {


//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
@@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
-// ct  = ft ◦ ct-1 + it ◦ c't
+// ct  = clip(ft ◦ ct-1 + it ◦ c't)
// ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht  = ot ◦ tanh(ct)
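The clip added to the ct equation matters for the new backprop ops below: wherever the cell state was saturated by clipping, the derivative through that element is zero. Assuming hard clipping to [-cellClip, cellClip], which is how the cell clipping T_ARG is described later in this file:

    c_t     = max(-cellClip, min(cellClip, u)),  with u = ft ◦ ct-1 + it ◦ c't
    dc_t/du = 1 if |u| < cellClip, else 0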
@@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// 2) [2, nOut, 4*nOut] when directionMode >= 2

// *******
-// peephole weights Wp:
+// peephole weights Wp, optional:
// 1) [3*nOut]    when directionMode <  2
// 2) [2, 3*nOut] when directionMode >= 2

// *******
-// biases b:
+// biases b, optional:
// 1) [4*nOut]    when directionMode <  2
// 2) [2, 4*nOut] when directionMode >= 2

// *******
-// sequence length array seqLen:
-// 1) [bS] always
+// sequence length array seqLen, optional:
+// 1) [bS]

// *******
-// initial output hI:
+// initial output hI, optional:
// 1) [bS, nOut]    when directionMode <  2
// 2) [2, bS, nOut] when directionMode >= 2

// *******
-// initial cell state cI (same shape as in hI):
+// initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut]    when directionMode <  2
// 2) [2, bS, nOut] when directionMode >= 2
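Because most of these inputs are optional, an input's index in the op's argument list is not fixed: the three mandatory arrays x, Wx, Wr occupy indices 0-2, and each optional input that is present takes the next free index, in the order listed above. This is the pattern the bp op below uses (the forward op follows the same convention):

    int count = 3;                                                     // x, Wx, Wr are inputs 0..2
    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr; // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
    // e.g. with only biases and an initial cell state present, b is input 3 and cI is input 4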
@@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// OUTPUTS:

// *******
-// output h:
+// output h, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
|
|||
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
|
||||
|
||||
// *******
|
||||
// output at last step hL:
|
||||
// output at last step hL, optional:
|
||||
// 1) [bS, nOut] when directionMode < 2
|
||||
// 2) [2, bS, nOut] when directionMode >= 2
|
||||
|
||||
// *******
|
||||
// cell state at last step cL (same shape as in hL):
|
||||
// cell state at last step cL (same shape as in hL), optional:
|
||||
// 1) [bS, nOut] when directionMode < 2
|
||||
// 2) [2, bS, nOut] when directionMode >= 2
|
||||
|
||||
// !!! dimension 4*nOut implies order it, ft, c't, ot
|
||||
// !!! dimension 3*nOut implies order it, ft, ot
|
||||
|
||||
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
|
||||
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
|
||||
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
|
||||
|
||||
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||
|
@@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
const auto hasInitC   = B_ARG(3);   // indicates whether initial cell state is provided
const auto hasPH      = B_ARG(4);   // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
-const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
-const auto retLastC   = B_ARG(7);   // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only
+const auto retLastC   = B_ARG(7);   // indicates whether to return cells state at last time step only

const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
@@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {

// evaluate dimensions
const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
-const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

// inputs validations
@@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) {
const auto Wr = INPUT_VARIABLE(2); // recurrent weights

// evaluate dimensions
-const Nd4jLong sL   = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) );
-const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
+const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

DataType type;
@@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) {
}


//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {

    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // sL   - sequence length, number of time steps
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS:

    // *******
    // input x:
    // 1) [sL, bS, nIn] when dataFormat == 0
    // 2) [bS, sL, nIn] when dataFormat == 1
    // 3) [bS, nIn, sL] when dataFormat == 2

    // *******
    // input weights Wx:
    // 1) [nIn, 4*nOut]    when directionMode <  2
    // 2) [2, nIn, 4*nOut] when directionMode >= 2

    // *******
    // recurrent weights Wr:
    // 1) [nOut, 4*nOut]    when directionMode <  2
    // 2) [2, nOut, 4*nOut] when directionMode >= 2

    // *******
    // peephole weights Wp, optional:
    // 1) [3*nOut]    when directionMode <  2
    // 2) [2, 3*nOut] when directionMode >= 2

    // *******
    // biases b, optional:
    // 1) [4*nOut]    when directionMode <  2
    // 2) [2, 4*nOut] when directionMode >= 2

    // *******
    // sequence length array seqLen, optional:
    // 1) [bS]

    // *******
    // initial output hI, optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // initial cell state cI (same shape as in hI), optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // gradient vs. output dLdh, optional:
    // 1) [sL, bS, nOut]   when directionMode <= 2 && dataFormat == 0
    // 2) [bS, sL, nOut]   when directionMode <= 2 && dataFormat == 1
    // 3) [bS, nOut, sL]   when directionMode <= 2 && dataFormat == 2
    // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
    // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
    // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
    // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3

    // *******
    // gradient vs output at last time step dLdhL, optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // gradient vs cell state at last time step dLdcL (same shape as in dLdhL), optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2


    // OUTPUTS:

    // *******
    // gradient vs. input dLdx:
    // 1) [sL, bS, nIn] when dataFormat == 0
    // 2) [bS, sL, nIn] when dataFormat == 1
    // 3) [bS, nIn, sL] when dataFormat == 2

    // *******
    // gradient vs. input weights dLdWx:
    // 1) [nIn, 4*nOut]    when directionMode <  2
    // 2) [2, nIn, 4*nOut] when directionMode >= 2

    // *******
    // gradient vs. recurrent weights dLdWr:
    // 1) [nOut, 4*nOut]    when directionMode <  2
    // 2) [2, nOut, 4*nOut] when directionMode >= 2

    // *******
    // gradient vs. peephole weights dLdWp, optional:
    // 1) [3*nOut]    when directionMode <  2
    // 2) [2, 3*nOut] when directionMode >= 2

    // *******
    // gradient vs. biases dLdb, optional:
    // 1) [4*nOut]    when directionMode <  2
    // 2) [2, 4*nOut] when directionMode >= 2

    // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!):
    // 1) [bS] always

    // *******
    // gradient vs. initial output dLdhI, optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2


    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    const auto dataFormat    = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(3); // activation for cell state (c)
    const auto outAct  = INT_ARG(4); // activation for output (h)

    const auto hasBiases  = B_ARG(0); // indicates whether biases array is provided
    const auto hasSeqLen  = B_ARG(1); // indicates whether seqLen array is provided
    const auto hasInitH   = B_ARG(2); // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3); // indicates whether initial cell state is provided
    const auto hasPH      = B_ARG(4); // indicates whether peephole connections are present
    const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1}
    const auto retLastH   = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given
    const auto retLastC   = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given

    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    uint count = 1;
    const auto cellClip  = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
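    // worked example of the variable-position t-args (illustrative values, not from a test):
    // with gateAct = 3 (affine), cellAct = 0 (tanh), outAct = 4 (leaky relu) the op reads
    // T_ARG(0) = cellClip, T_ARG(1) = gateAlpha, T_ARG(2) = gateBeta, T_ARG(3) = outAlpha
    // and nothing else, since tanh needs no alpha/beta and leaky relu has alpha only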
    REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !");
    REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !");

    const auto x  = INPUT_VARIABLE(0); // input
    const auto Wx = INPUT_VARIABLE(1); // input weights
    const auto Wr = INPUT_VARIABLE(2); // recurrent weights

    count = 3;
    const auto b      = hasBiases  ? INPUT_VARIABLE(count++) : nullptr; // biases
    const auto seqLen = hasSeqLen  ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
    const auto hI     = hasInitH   ? INPUT_VARIABLE(count++) : nullptr; // initial output
    const auto cI     = hasInitC   ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
    const auto Wp     = hasPH      ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
    const auto dLdh   = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output
    const auto dLdhL  = retLastH   ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step
    const auto dLdcL  = retLastC   ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step

    count = 3;
    auto dLdx  = OUTPUT_VARIABLE(0);                              // gradient vs. input
    auto dLdWx = OUTPUT_NULLIFIED(1);                             // gradient vs. input weights
    auto dLdWr = OUTPUT_NULLIFIED(2);                             // gradient vs. recurrent weights
    auto dLdb  = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases
    auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++)   : nullptr; // gradient vs. seqLen vector, we don't calculate it !!!
    auto dLdhI = hasInitH  ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output
    auto dLdcI = hasInitC  ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state
    auto dLdWp = hasPH     ? OUTPUT_NULLIFIED(count)   : nullptr; // gradient vs. peephole weights
    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // inputs validations
    if(directionMode < 2) { // no bidirectional

        // Wx validation
        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
        // gradient vs. output at last time step validation
        if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
        // gradient vs. cell state at last time step validation
        if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
    }
    else { // bidirectional
        // Wx validation
        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
        // gradient vs. output at last time step validation
        if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
        // gradient vs. cell state at last time step validation
        if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
    }

    // gradient vs. output validation
    if(dLdh) {
        int factor = directionMode <= 2 ? 1 : 2;
        std::vector<Nd4jLong> expdLdhShape;
        if(dataFormat == 0)      expdLdhShape = std::vector<Nd4jLong>{sL, bS, factor*nOut};
        else if(dataFormat == 1) expdLdhShape = std::vector<Nd4jLong>{bS, sL, factor*nOut};
        else if(dataFormat == 2) expdLdhShape = std::vector<Nd4jLong>{bS, factor*nOut, sL};
        else                     expdLdhShape = std::vector<Nd4jLong>{sL, 2, bS, nOut};
        REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
    }
    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    if(directionMode == 0) { // forward

        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
    }
    else if(directionMode == 1) { // backward

        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
    }
    else { // bidirectional

        NDArray WxFwd    = (*Wx)({0,1, 0,0, 0,0});
        NDArray WxBwd    = (*Wx)({1,2, 0,0, 0,0});
        NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0});
        NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0});

        NDArray WrFwd    = (*Wr)({0,1, 0,0, 0,0});
        NDArray WrBwd    = (*Wr)({1,2, 0,0, 0,0});
        NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0});
        NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0});
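        // note: the subarray syntax takes one {start,end} interval per dimension, with {0,0}
        // meaning the whole dimension; e.g. (*Wx)({0,1, 0,0, 0,0}) selects index 0 along the
        // leading (direction) axis of the [2, nIn, 4*nOut] weights, i.e. the forward slice,
        // while (*Wx)({1,2, 0,0, 0,0}) selects the backward one; the same pattern is used
        // for Wr, Wp, b, hI, cI and the corresponding gradients below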
        NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
                *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr),
                *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr),
                *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr);

        if(Wp) {
            WpFwd    = new NDArray((*Wp)({0,1, 0,0}));
            WpBwd    = new NDArray((*Wp)({1,2, 0,0}));
            dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0}));
            dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0}));
        }
        if(b) {
            bFwd    = new NDArray((*b)({0,1, 0,0}));
            bBwd    = new NDArray((*b)({1,2, 0,0}));
            dLdbFwd = new NDArray((*dLdb)({0,1, 0,0}));
            dLdbBwd = new NDArray((*dLdb)({1,2, 0,0}));
        }
        if(hI) {
            hIFwd    = new NDArray((*hI)({0,1, 0,0, 0,0}));
            hIBwd    = new NDArray((*hI)({1,2, 0,0, 0,0}));
            dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0}));
            dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0}));
        }
        if(cI) {
            cIFwd    = new NDArray((*cI)({0,1, 0,0, 0,0}));
            cIBwd    = new NDArray((*cI)({1,2, 0,0, 0,0}));
            dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0}));
            dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0}));
        }
        if(dLdhL) {
            dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0}));
            dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0}));
        }
        if(dLdcL) {
            dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0}));
            dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
        }

        // FIXME looks like sum (directionMode == 2) is impossible for backprop
        if(dLdh) {
            if(directionMode == 2) { // sum
                REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
                // dLdhFwd = dLdh;
                // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content
            }
            else if(directionMode == 3) { // concat
                dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
                dLdhBwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0}));
            }
            else { // directionMode == 4
                dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0}));
                dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0}));
            }
        }

        helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
        NDArray dLdxBwd = dLdx->ulike();
        helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);

        *dLdx += dLdxBwd;

        delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
        delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
        delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd;
        delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;

        if(dLdhFwd != dLdh)
            delete dLdhFwd;
    }

    return Status::OK();
}
DECLARE_TYPES(lstmLayer_bp) {
    getOpDescriptor()
        ->setAllowedInputTypes(sd::DataType::ANY)
        ->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayer_bp) {

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
    const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
    const auto hasInitH  = B_ARG(2); // indicates whether initial output is provided
    const auto hasInitC  = B_ARG(3); // indicates whether initial cell state is provided
    const auto hasPH     = B_ARG(4); // indicates whether peephole connections are present

    int count = 3;
    const auto x      = INPUT_VARIABLE(0);                         // input
    const auto Wx     = INPUT_VARIABLE(1);                         // input weights
    const auto Wr     = INPUT_VARIABLE(2);                         // recurrent weights
    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr; // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
    const auto Wp     = hasPH     ? INPUT_VARIABLE(count++) : nullptr; // peephole weights

    std::vector<Nd4jLong*> outShapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};

    if(b != nullptr)
        outShapes.push_back(b->getShapeInfo());
    if(seqLen != nullptr)
        outShapes.push_back(seqLen->getShapeInfo());
    if(hI != nullptr)
        outShapes.push_back(hI->getShapeInfo());
    if(cI != nullptr)
        outShapes.push_back(cI->getShapeInfo());
    if(Wp != nullptr)
        outShapes.push_back(Wp->getShapeInfo());

    return new ShapeList(outShapes);
}

}
}
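A sketch of how the two ops pair up in a gradient-check test (shapes for dataFormat = 0, unidirectional forward; the exact OpArgsHolder/bool-argument plumbing is assumed from the GradCheck changes above, not spelled out in this diff):

    const int sL = 3, bS = 2, nIn = 4, nOut = 3;

    NDArray x('c',    {sL, bS, nIn},  sd::DataType::DOUBLE);
    NDArray Wx('c',   {nIn, 4*nOut},  sd::DataType::DOUBLE);
    NDArray Wr('c',   {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
    x.linspace(0.01, 0.01);
    Wx.linspace(0.005, 0.005);
    Wr.linspace(0.006, 0.006);

    // iArgs: dataFormat=0, directionMode=0 (fwd), gateAct=2 (sigmoid), cellAct=0, outAct=0 (tanh)
    // bArgs: no optional inputs, return the full h sequence only
    std::vector<Nd4jLong> iArgs = {0, 0, 2, 0, 0};
    std::vector<bool>     bArgs = {false, false, false, false, false, true, false, false};

    sd::ops::lstmLayer    opFF;
    sd::ops::lstmLayer_bp opBP;

    OpArgsHolder argsHolderFF({&x, &Wx, &Wr},        {0.}, iArgs, bArgs);
    OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &dLdh}, {0.}, iArgs, bArgs);

    bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);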
@@ -0,0 +1,339 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayerCell)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

namespace sd {
namespace ops {


//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {

    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS:
    // input x:                          [bS, nIn] or [nIn]
    // input weights Wx:                 [nIn, 4*nOut]
    // recurrent weights Wr:             [nOut, 4*nOut]
    // initial (previous) output hI:     [bS, nOut] or [nOut]
    // initial (previous) cell state cI: [bS, nOut] or [nOut]
    // biases b (optional):              [4*nOut]
    // peephole weights Wp (optional):   [3*nOut]

    // OUTPUTS:
    // current output h:     [bS, nOut] or [nOut]
    // current cell state c: [bS, nOut] or [nOut]

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(1); // activation for cell state (c)
    const auto outAct  = INT_ARG(2); // activation for output (h)

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1); // indicates whether peephole connections are present

    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    uint count = 1;
    const auto cellClip  = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    count = 3;
    const auto x  = INPUT_VARIABLE(0);                             // input
    const auto Wx = INPUT_VARIABLE(1);                             // input weights
    const auto Wr = INPUT_VARIABLE(2);                             // recurrent weights
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
    const auto hI = INPUT_VARIABLE(count++);                       // initial output
    const auto cI = INPUT_VARIABLE(count++);                       // initial cell state
    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;       // peephole weights

    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");

    auto h = OUTPUT_VARIABLE(0);
    auto c = OUTPUT_VARIABLE(1);

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // inputs validations
    // Wx validation
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());

    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);

    return Status::OK();
}

DECLARE_TYPES(lstmLayerCell) {
    getOpDescriptor()
        ->setAllowedInputTypes(sd::DataType::ANY)
        ->setAllowedOutputTypes({ALL_FLOATS});
}


DECLARE_SHAPE_FN(lstmLayerCell) {

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided

    uint count = hasBiases ? 4 : 3;
    const auto hI = INPUT_VARIABLE(count++); // initial output
    const auto cI = INPUT_VARIABLE(count);   // initial cell state

    return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()});
}
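A sketch of a single-step call to the cell op (values illustrative; evaluate is the test-side convenience API and its exact availability/signature here is an assumption):

    const int bS = 2, nIn = 4, nOut = 3;

    NDArray x('c',  {bS, nIn},      sd::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut},  sd::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32);
    NDArray b('c',  {4*nOut},       sd::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut},     sd::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut},     sd::DataType::FLOAT32);

    sd::ops::lstmLayerCell op;
    // tArgs: cellClip = 0 (off); iArgs: gateAct=2 (sigmoid), cellAct=0, outAct=0 (tanh)
    // bArgs: biases present, no peepholes
    auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, {0.}, {2, 0, 0}, {true, false});
    auto h = results.at(0); // current output h, same shape as hI
    auto c = results.at(1); // current cell state c, same shape as cI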
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
|
||||
|
||||
// equations (no peephole connections)
|
||||
// it = σ(Wxi * xt + Wri * ht-1 + bi)
|
||||
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
|
||||
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
|
||||
// ct = ft ◦ ct-1 + it ◦ c't
|
||||
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
|
||||
// ht = ot ◦ tanh(ct)
|
||||
|
||||
// equations (peephole connections are present)
|
||||
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
|
||||
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
|
||||
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
|
||||
// ct = clip(ft ◦ ct-1 + it ◦ c't)
|
||||
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
|
||||
// ht = ot ◦ tanh(ct)
|
||||
|
||||
// notations:
|
||||
// bS - batch size
|
||||
// nIn - input size
|
||||
// nOut - output size (hidden size)
|
||||
|
||||
// INPUTS:
|
||||
// input x: [bS, nIn] or [nIn]
|
||||
// input weights Wx: [nIn, 4*nOut]
|
||||
// recurrent weights Wr: [nOut, 4*nOut]
|
||||
// initial (previous) output hI: [bS, nOut] or [nOut]
|
||||
// initial (previous) cell state cI: [bS, nOut] or [nOut]
|
||||
// gradient wrt output dLdh: [bS, nOut] or [nOut]
|
||||
// gradient wrt cell state dLdc: [bS, nOut] or [nOut]
|
||||
// peephole weights Wp (optional): [3*nOut]
|
||||
// biases b (optional): [4*nOut]
|
||||
|
||||
// OUTPUTS:
|
||||
// gradient wrt x dLdx: [bS, nIn] or [nIn]
|
||||
// gradient wrt Wx dLdWx: [nIn, 4*nOut]
|
||||
// gradient wrt Wr dLdWr: [nOut, 4*nOut]
|
||||
// gradient wrt hI dLdhI: [bS, nOut] or [nOut]
|
||||
// gradient wrt cI dLdcI: [bS, nOut] or [nOut]
|
||||
// gradient wrt b dLdb (optional): [4*nOut]
|
||||
// gradient wrt Wp dLdWp (optional): [3*nOut]
|
||||
|
||||
|
||||
// !!! dimension 4*nOut implies order it, ft, c't, ot
|
||||
// !!! dimension 3*nOut implies order it, ft, ot
|
||||
|
||||
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||||
const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
|
||||
const auto cellAct = INT_ARG(1); // activation for cell state (c)
|
||||
const auto outAct = INT_ARG(2); // activation for output (h)
|
||||
|
||||
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
|
||||
const auto hasPH = B_ARG(1); // indicates whether peephole connections are present
|
||||
|
||||
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
|
||||
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
|
||||
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
|
||||
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
|
||||
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
|
||||
const auto outActHasBeta = outAct == 3 || outAct == 6;
|
||||
|
||||
uint count = 1;
|
||||
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
|
||||
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
|
||||
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
|
||||
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
|
||||
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
|
||||
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
|
||||
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
|
||||
|
||||
count = 3;
|
||||
const auto x = INPUT_VARIABLE(0); // input
|
||||
const auto Wx = INPUT_VARIABLE(1); // input weights
|
||||
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
|
||||
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
|
||||
const auto hI = INPUT_VARIABLE(count++); // initial output
|
||||
const auto cI = INPUT_VARIABLE(count++); // initial cell state
|
||||
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
|
||||
const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output
|
||||
|
||||
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");
|
||||
|
||||
count = 3;
|
||||
    auto dLdx  = OUTPUT_VARIABLE(0);
    auto dLdWx = OUTPUT_VARIABLE(1);
    auto dLdWr = OUTPUT_VARIABLE(2);
    auto dLdb  = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
    auto dLdhI = OUTPUT_VARIABLE(count++);
    auto dLdcI = OUTPUT_VARIABLE(count++);
    auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // input validations
    // Wx validation
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
    if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());

    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct),  static_cast<float>(outAlpha),  static_cast<float>(outBeta)};

    std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});

    NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
    NDArray a = z.ulike();
    NDArray h = cI->ulike();
    NDArray c = cI->ulike();

    // recompute the feed-forward quantities z, a, h, c needed by the backprop helper
    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);

    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);

    return Status::OK();
}
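For reference, the temporaries z, a, h, c recomputed above correspond to the standard peephole LSTM cell. A sketch of the assumed formulation (the gate, cell, and output activations are configurable through params; sigmoid and tanh are shown because they are the defaults the tests below use, and the exact placement of clipping and peepholes is my reading of the code, not a quote of it):

\begin{aligned}
z &= x\,W_x + h_I\,W_r + b, \qquad z = [z_i,\; z_f,\; z_g,\; z_o] \\
i &= \sigma(z_i + w_{pi} \odot c_I), \quad f = \sigma(z_f + w_{pf} \odot c_I), \quad g = \tanh(z_g) \\
c &= \operatorname{clip}(f \odot c_I + i \odot g) \\
o &= \sigma(z_o + w_{po} \odot c), \qquad h = o \odot \tanh(c)
\end{aligned}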
DECLARE_TYPES(lstmLayerCellBp) {
    getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayerCellBp) {

    const auto hasBiases = B_ARG(0);   // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1);   // indicates whether peephole connections are present

    uint count = 3;
    const auto x  = INPUT_VARIABLE(0);                              // input
    const auto Wx = INPUT_VARIABLE(1);                              // input weights
    const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
    const auto hI = INPUT_VARIABLE(count++);                        // initial output
    const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;        // peephole weights

    // each gradient output has the same shape as its corresponding input
    std::vector<Nd4jLong*> shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};

    if(b != nullptr)
        shapes.push_back(b->getShapeInfo());

    shapes.push_back(hI->getShapeInfo());
    shapes.push_back(cI->getShapeInfo());

    if(Wp != nullptr)
        shapes.push_back(Wp->getShapeInfo());

    return new ShapeList(shapes);
}

}
}

#endif
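Before the header declarations that follow, here is how this pair of ops is driven in the tests further down. A minimal sketch, assuming the test suite's usual evaluate() helper and the input order established above; all values are illustrative:

    const int bS = 2, nIn = 4, nOut = 3;
    NDArray x ('c', {bS, nIn},      sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut},  sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b ('c', {4*nOut},       sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut},     sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut},     sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut},       sd::DataType::DOUBLE);
    x.linspace(-2, 0.1); hI.linspace(-1.5, 0.1); cI.linspace(0.7, -0.1);
    Wx.linspace(1, -0.1); Wr.linspace(-1, 0.1); Wp.linspace(0.2, 0.2); b.linspace(1, -0.15);

    sd::ops::lstmLayerCell cellOp;
    auto res = cellOp.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp},
                               {0. /*cellClip*/},
                               {2, 0, 0 /*gateAct, cellAct, outAct*/},
                               {true /*hasBiases*/, true /*hasPH*/});
    NDArray* h = res.at(0);   // updated output
    NDArray* c = res.at(1);   // updated cell state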
@@ -149,6 +149,13 @@ namespace ops {
        DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2);
    #endif

    #if NOT_EXCLUDED(OP_lstmLayerCell)
        DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3);
    #endif
    #if NOT_EXCLUDED(OP_lstmLayerCell)
        DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3);
    #endif

    //////////////////////////////////////////////////////////////////////////
    /**

@@ -236,6 +243,11 @@ namespace ops {
        DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
    #endif

    //////////////////////////////////////////////////////////////////////////
    #if NOT_EXCLUDED(OP_lstmLayer)
        DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5);
    #endif

    //////////////////////////////////////////////////////////////////////////
    /**
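A note on reading these macros. The argument layout below is the usual libnd4j convention; treat it as an assumption to be verified against the macro definition rather than a quote of it:

    // DECLARE_CUSTOM_OP(name, nInputs, nOutputs, inplaceable, nTArgs, nIArgs)
    DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3);
    // -> 7 input arrays, 5 output gradient arrays, not an in-place op,
    //    1 floating-point arg (cellClip) and 3 integer args (gateAct, cellAct, outAct)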
File diff suppressed because it is too large
@@ -22,7 +22,6 @@
    #define LIBND4J_LSTMLAYER_H

    #include <ops/declarable/helpers/helpers.h>
    #include <ops/declarable/helpers/activations.h>

    namespace sd {
    namespace ops {

@@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
                               const std::vector<float>& params,
                               NDArray* h, NDArray* c);

    //////////////////////////////////////////////////////////////////////////
    // this auxiliary feed-forward pass should be run before backprop
    void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                                   const std::vector<float>& params,
                                   NDArray* z, NDArray* a, NDArray* h, NDArray* c);

    //////////////////////////////////////////////////////////////////////////
    void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                                     const NDArray* dLdh, const NDArray* dLdc,
                                     const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
                                     NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);

    //////////////////////////////////////////////////////////////////////////
    void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,

@@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND
                                       NDArray* h, NDArray* hL, NDArray* cL);

    //////////////////////////////////////////////////////////////////////////
    static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

        switch (opId) {
            case 0:
                (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z);
                break;
            case 1:
                (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
                break;
            case 2:
                (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
                break;
            case 3: {
                ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
                (const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
                break;
            }
            case 4:
                (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
                break;
            case 5:
                helpers::thresholdRelu(x.getContext(), x, alpha, z);
                break;
            case 6: {
                ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
                (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
                break;
            }
            case 7:
                (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
                break;
            case 8:
                (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
                break;
            case 9:
                (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
                break;
            case 10:
                (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
                break;
            default:
                throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
        }
    }
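A small usage sketch of the dispatcher above, with illustrative values only. I am assuming transform::Affine computes alpha*x + beta, which is what the ExtraArguments ordering suggests:

    NDArray in('c', {4}, sd::DataType::FLOAT32);
    NDArray out = in.ulike();
    in.linspace(-1.f, 0.5f);                                            // in = [-1, -0.5, 0, 0.5]
    applyActivation(in, /*opId=*/0, 0.f, 0.f, out);                     // out = tanh(in)
    applyActivation(in, /*opId=*/3, /*alpha=*/2.f, /*beta=*/1.f, out);  // out = 2*in + 1 (assumed)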
    //////////////////////////////////////////////////////////////////////////
    static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {

        if(dataFormat == 0 || dataFormat == 3)
            return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]

        if(dataFormat == 1)
            return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL, nIn]

        return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
    }

    //////////////////////////////////////////////////////////////////////////
    static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {

        if(dataFormat == 0 || dataFormat == 3)
            return t * bS + b;    // TNS: shape [sL, bS, nIn]

        return b * sL + t;        // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
    }

    void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                                         const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
                                         const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
                                         const std::vector<float>& params, const bool forward,
                                         NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp);

    }
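A worked example of the (t, b) flattening above, with sL = 3, bS = 2, t = 1, b = 1:

    const int sL = 3, bS = 2, t = 1, b = 1;
    const int tns = t * bS + b;   // time-major (formats 0, 3): 1*2 + 1 = 3
    const int nts = b * sL + t;   // batch-major (formats 1, 2): 1*3 + 1 = 4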
@@ -1441,7 +1441,7 @@ namespace simdOps {
        }

        op_def static Z op(X d1) {
-           return d1;
+           return static_cast<Z>(d1);
        }
    };
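The one-line change above makes the X-to-Z conversion explicit. A minimal stand-alone illustration of why that matters for mixed-type instantiations (a sketch, not the library's code):

    template <typename X, typename Z>
    static Z identityOp(X d1) {
        return static_cast<Z>(d1);   // explicit X -> Z promotion/narrowing
    }
    // e.g. identityOp<double, float>(1.0) narrows deliberately instead of
    // relying on an implicit conversion at the return statement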
@@ -2434,6 +2434,19 @@ namespace simdOps {
        }
    };

    template <typename X, typename Y, typename Z>
    class RELUDerivative {
    public:
        no_op_exec_special_same
        no_op_exec_special_same_cuda

        op_def static Z op(X d1, Y d2, Z *params) {
            auto xt = static_cast<Z>(d1);
            auto xf = static_cast<Z>(d2);
            return xt > xf ? static_cast<Z>(1.f) : static_cast<Z>(0.f);
        }
    };

    template <typename X, typename Y, typename Z>
    class SXELogitsSmoother {
    public:
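The added RELUDerivative pairwise op boils down to a threshold comparison; restated stand-alone (a sketch of the same semantics):

    // 1 where the first operand exceeds the second (the threshold), else 0,
    // i.e. the derivative of max(x, t) with respect to x away from x == t:
    static double reluDerivative(double x, double threshold) {
        return x > threshold ? 1.0 : 0.0;
    }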
@@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
    auto z = result.at(0);
    ASSERT_TRUE(z->isEmpty());
}

TEST_F(DeclarableOpsTests13, test_empty_range_2) {

@@ -262,7 +262,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
    ASSERT_EQ(result.status(), Status::OK());
    //result.at(0)->printBuffer("Output");
    ASSERT_TRUE(exp1.equalsTo(result.at(0)));
}

TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {

@@ -286,7 +286,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {
    ASSERT_EQ(result.status(), Status::OK());
    //result.at(0)->printBuffer("Output");
    ASSERT_TRUE(exp.equalsTo(result.at(0)));
}

TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {

@@ -312,7 +312,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
    //exp.printBuffer("Expect");
    //result.at(0)->printShapeInfo("Shape output");
    ASSERT_TRUE(exp.equalsTo(result.at(0)));
}

TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {

@@ -349,7 +349,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
    //result.at(2)->printBuffer("Symmetrized2");
    // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
    ASSERT_TRUE(exp.equalsTo(result.at(2)));
}

TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {

@@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
    //exp.printBuffer("EXPect symm3");
    // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
    //ASSERT_TRUE(exp.equalsTo(result.at(0)));
}

TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {

@@ -398,7 +398,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
    //exp.printBuffer("EXPect symm3");
    // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
    ASSERT_TRUE(exp4.equalsTo(res));
}

TEST_F(DeclarableOpsTests13, CellContains_test_1) {

@@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
    //exp.printBuffer("EXPect symm3");
    // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
    //ASSERT_TRUE(exp.equalsTo(result.at(0)));
}

////////////////////////////////////////////////////////////////////

@@ -712,7 +712,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) {

    ASSERT_EQ(e, *z);
}

TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {

@@ -1109,6 +1109,7 @@ TEST_F(DeclarableOpsTests13, mergeavg_bp_1) {
    }
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_1) {
@@ -1200,7 +1201,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
    const auto hasInitC   = true;  // initial cell state is provided
    const auto hasPH      = false; // peephole connections are absent
    const auto retFullSeq = true;  // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
-   const auto retLastH   = true;  // do not return output at last time step
+   const auto retLastH   = true;  // return output at last time step
    const auto retLastC   = true;  // return cell state at last time step

    const double cellClip = 0;     // do not apply clipping
@@ -1398,7 +1399,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));
}

///////////////////////////////////////////////////////////////////

@@ -1640,7 +1641,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}

@@ -1718,7 +1719,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}

@@ -1805,7 +1806,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}

@@ -1890,7 +1891,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}

@@ -1970,7 +1971,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}
@@ -2061,10 +2062,528 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 3;

    const int dataFormat    = 0;    // [sL,bS,nIn]
    const int directionMode = 0;    // forward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = false;  // seqLen array is not provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = false;  // do not return h for each time step
    const auto retLastH   = true;   // return output at last time step
    const auto retLastC   = true;   // return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::SUM, {0});

    ASSERT_TRUE(isGradCorrect);
}
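All of these bp tests rely on GradCheck comparing opBP's analytic gradients against numeric ones built from repeated opFF evaluations. A minimal sketch of the central-difference estimate presumably used (the epsilon and the loss reduction are assumptions; the real implementation lives in GradCheck::checkGrad):

    #include <functional>

    // dL/dw ~= (L(w + eps) - L(w - eps)) / (2 * eps), one parameter at a time
    static double numericGrad(const std::function<double(double)>& lossAt,
                              double w, double eps = 1e-5) {
        return (lossAt(w + eps) - lossAt(w - eps)) / (2.0 * eps);
    }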
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 3;

    const int dataFormat    = 0;    // [sL,bS,nIn]
    const int directionMode = 0;    // forward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = false;  // seqLen array is not provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = false;  // do not return whole h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = true;   // return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);

    ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 3;

    const int dataFormat    = 1;    // [bS,sL,nIn]
    const int directionMode = 0;    // forward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = false;  // seqLen array is not provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [bS,sL,nOut]
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = true;   // return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
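The shape of dLdh tracks dataFormat whenever retFullSeq is on. Collecting the convention these tests use into one place (a sketch; the NST case appears in bp_4 and bp_6 below):

    #include <vector>

    // expected dLdh shape for a unidirectional layer with retFullSeq = true
    static std::vector<Nd4jLong> dLdhShape(int dataFormat, Nd4jLong sL, Nd4jLong bS, Nd4jLong nOut) {
        if (dataFormat == 0) return {sL, bS, nOut};   // TNS
        if (dataFormat == 1) return {bS, sL, nOut};   // NTS (lstmLayer_bp_3)
        return {bS, nOut, sL};                        // NST (lstmLayer_bp_4, bp_6)
    }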
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {

    const int sL   = 4;
    const int bS   = 3;
    const int nIn  = 3;
    const int nOut = 2;

    const int dataFormat    = 2;    // [bS, nIn, sL]
    const int directionMode = 0;    // forward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = true;   // seqLen array is provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = true;   // return whole h sequence (dLdh is provided per time step)
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    // the 'false' entry excludes seqLen, which is not differentiable, from the check
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
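This is the first test exercising seqLen, which is where the fix noted in the commit history ("nullify gradients at eliminated time steps when a sequence length array is present") matters. A hypothetical, simplified sketch of that masking over a flattened time-major gradient buffer; the layout and the names are mine for illustration, not the helper's:

    #include <vector>

    static void nullifyEliminatedSteps(std::vector<double>& dLdx,        // [sL * bS], time-major, hypothetical layout
                                       const std::vector<int>& seqLen,   // per-batch sequence lengths
                                       int sL, int bS) {
        for (int b = 0; b < bS; ++b)
            for (int t = seqLen[b]; t < sL; ++t)
                dLdx[t * bS + b] = 0.0;   // a step past seqLen[b] contributes no gradient
    }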
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 3;

    const int dataFormat    = 1;    // [bS,sL,nIn]
    const int directionMode = 1;    // backward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = false;  // seqLen array is not provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = false;  // do not return h for each time step
    const auto retLastH   = true;   // return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 2;

    const int dataFormat    = 2;    // [bS, nIn, sL]
    const int directionMode = 1;    // backward
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = true;   // seqLen array is provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = true;   // return whole h sequence (dLdh is provided per time step)
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 2;

    const int dataFormat    = 2;    // [bS, nIn, sL]
    const int directionMode = 2;    // bidirectional sum
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = true;   // seqLen array is provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = false;  // do not return h for each time step
    const auto retLastH   = true;   // return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
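Note that in the bidirectional tests (bp_7 through bp_9) the weights, initial states, and peepholes all gain a leading dimension of 2, one slice per direction. The directionMode values 2, 3, and 4 differ only in how the two per-direction outputs are merged; a sketch with plain vectors, following the comments these tests carry:

    #include <vector>

    // mode 2 (sum): h = hFwd + hBwd, same nOut as each direction
    static std::vector<double> combineSum(const std::vector<double>& hFwd,
                                          const std::vector<double>& hBwd) {
        std::vector<double> h(hFwd.size());
        for (size_t i = 0; i < hFwd.size(); ++i)
            h[i] = hFwd[i] + hBwd[i];
        return h;
    }
    // mode 3 concatenates along nOut (giving 2*nOut, as dLdh in bp_8 shows);
    // mode 4 stacks along a new axis (giving [sL, 2, bS, nOut] in bp_9).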
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 2;

    const int dataFormat    = 1;    // [bS,sL,nIn]
    const int directionMode = 3;    // bidirectional concat
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = true;   // seqLen array is provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = true;   // return whole h sequence (dLdh is provided per time step)
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {

    const int sL   = 3;
    const int bS   = 2;
    const int nIn  = 2;
    const int nOut = 2;

    const int dataFormat    = 3;    // [sL, bS, nIn], output h gains an extra direction dim
    const int directionMode = 4;    // bidirectional extra output dim
    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct       = 0;    // tanh activation for cell state
    const int outAct        = 0;    // tanh activation for output

    const bool hasBiases  = true;   // biases array is provided
    const bool hasSeqLen  = true;   // seqLen array is provided
    const auto hasInitH   = true;   // initial output is provided
    const auto hasInitC   = true;   // initial cell state is provided
    const auto hasPH      = true;   // peephole connections are provided
    const auto retFullSeq = true;   // return whole h sequence (dLdh is provided per time step)
    const auto retLastH   = false;  // do not return output at last time step
    const auto retLastC   = false;  // do not return cell state at last time step

    const double cellClip = 0.5;    // clipping value

    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
    NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);

    x.linspace(-2,0.1);
    hI.linspace(-1.5,0.1);
    cI.linspace(0.7,-0.1);
    Wx.linspace(1,-0.1);
    Wr.linspace(-1,0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayer opFF;
    sd::ops::lstmLayer_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});

    ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, batchnorm_test1) {

@@ -2091,7 +2610,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) {
    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));
}

////////////////////////////////////////////////////////////////////

@@ -2233,7 +2752,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) {
    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));
}

////////////////////////////////////////////////////////////////////

@@ -2345,7 +2864,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
    ASSERT_TRUE(expected.isSameShape(*output));
    ASSERT_TRUE(expected.equalsTo(output));
}

////////////////////////////////////////////////////////////////////////////////

@@ -2387,7 +2906,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));
}

@@ -2642,7 +3161,7 @@ return;
    ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
    ASSERT_TRUE(expdLdB.equalsTo(dLdB));
}

////////////////////////////////////////////////////////////////////
@@ -844,5 +844,78 @@ TEST_F(PlaygroundTests, my) {
    printf("time: %i \n", time);
}

/*
///////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, lstmLayerCellBp_1) {

    const int bS   = 2;
    const int nIn  = 4;
    const int nOut = 3;
    // const int nIn  = 8;
    // const int nOut = 6;

    const float cellClip    = 1.1;  // clipping value
    const Nd4jLong gateAct  = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
    const float gateAlpha   = 0;    // alpha value for gate activation, not required for sigmoid
    const float gateBeta    = 0;    // beta value for gate activation, not required for sigmoid
    const Nd4jLong cellAct  = 0;    // tanh activation for cell state
    const float cellAlpha   = 0;    // alpha value for cell state activation, not required for tanh
    const float cellBeta    = 0;    // beta value for cell state activation, not required for tanh
    const Nd4jLong outAct   = 0;    // tanh activation for output
    const float outAlpha    = 0;    // alpha value for output activation, not required for tanh
    const float outBeta     = 0;    // beta value for output activation, not required for tanh

    NDArray x ('c', {bS, nIn},    sd::DataType::DOUBLE);
    NDArray hI('c', {bS, nOut},   sd::DataType::DOUBLE);
    NDArray cI('c', {bS, nOut},   sd::DataType::DOUBLE);
    NDArray dLdh('c', {bS, nOut}, sd::DataType::DOUBLE);
    NDArray dLdc('c', {bS, nOut}, sd::DataType::DOUBLE);

    // NDArray x ('c', {nIn},    sd::DataType::DOUBLE);
    // NDArray hI('c', {nOut},   sd::DataType::DOUBLE);
    // NDArray cI('c', {nOut},   sd::DataType::DOUBLE);
    // NDArray dLdh('c', {nOut}, sd::DataType::DOUBLE);
    // NDArray dLdc('c', {nOut}, sd::DataType::DOUBLE);

    NDArray Wx('c', {nIn, 4*nOut},  sd::DataType::DOUBLE);
    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
    NDArray b ('c', {4*nOut},       sd::DataType::DOUBLE);
    NDArray Wp('c', {3*nOut},       sd::DataType::DOUBLE);

    x.linspace(-4,1);
    hI.linspace(-2.5,0.5);
    cI.linspace(-3,0.5);
    Wx.linspace(0,0.1);
    Wr.linspace(3,-0.1);
    Wp.linspace(0.2,0.2);
    b.linspace(1,-0.15);

    // x.assign(1.);
    // hI.assign(2.);
    // cI.assign(3.);
    // Wx.assign(0.5);
    // Wr.assign(0.5);
    // Wp.assign(0.75);
    // b.assign(0.7);

    std::vector<double> tArgs = {cellClip};
    std::vector<Nd4jLong> iArgs = {gateAct, cellAct, outAct};

    // std::vector<bool> bArgs = {false, false};
    // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs);
    // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs);

    std::vector<bool> bArgs = {true, true};
    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);

    sd::ops::lstmLayerCell opFF;
    sd::ops::lstmLayerCellBp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true});
    ASSERT_TRUE(isGradCorrect);
}
*/