From 23e4aa99ad8b254c7dd32a632be4658c540bd5af Mon Sep 17 00:00:00 2001
From: Yurii Shyrma
Date: Mon, 13 Apr 2020 13:21:51 +0300
Subject: [PATCH] Shyrma lstm layer bp (#370)

* - start working on bp for lstm

Signed-off-by: Yurii

* - further working on bp for lstmLayer

Signed-off-by: Yurii

* - minor change

Signed-off-by: Yurii

* - further work on bp for lstmLayer 2

Signed-off-by: Yurii

* - further work on bp for lstmLayer 3

Signed-off-by: Yurii

* - further work on bp for lstmLayer 4

Signed-off-by: Yurii

* - further work on bp for lstmLayer 5

Signed-off-by: Yurii

* - further work on bp for lstmLayer 6

Signed-off-by: Yurii

* - further work on bp for lstmLayer 7

Signed-off-by: Yurii

* - further work on bp for lstmLayer 8

Signed-off-by: Yurii

* - further work on bp for lstmLayer 9

Signed-off-by: Yurii

* - provide lstmLayerCell and lstmLayerCellBp as separate CUSTOM_OPs

Signed-off-by: Yurii

* - testing and fixing lstmLayerCellBp helper

Signed-off-by: Yurii

* - implement lstmLayerCellBp as separate op

Signed-off-by: Yurii

* - implement lstmLayerBp as separate op (not tested)

Signed-off-by: Yurii

* - fixing calculations of dLdWp and dLdb in lstmLayerCellBp

Signed-off-by: Yurii

* - further work on bp for lstmLayer 10

Signed-off-by: Yurii

* - fixing typo in lstmLayerTimeLoop

Signed-off-by: Yurii

* - forgot to perform clipping of c array and calculate corresponding derivative in lstmLayerCellBp

Signed-off-by: Yurii

* - further work on bp for lstmLayer 10

Signed-off-by: Yurii

* - testing and fixing bugs in lstmLayer_bp op 1

Signed-off-by: Yurii

* - testing and fixing bugs in lstmLayer_bp op 2

Signed-off-by: Yurii

* - turn off heavy tests for cuda for lstmLayer_bp op

Signed-off-by: Yurii

* - forgot to nullify gradients at eliminated time steps (when sequence length array is present)

Signed-off-by: Yurii
---
 libnd4j/include/array/NDArray.hXX             |    1 -
 libnd4j/include/array/cpu/NDArrayLambda.hpp   |    2 +-
 libnd4j/include/helpers/GradCheck.h           |    6 +-
 libnd4j/include/helpers/cpu/MmulHelper.cpp    |    8 +-
 libnd4j/include/helpers/impl/GradCheck.cpp    |   52 +-
 libnd4j/include/loops/legacy_ops.h            |    3 +-
 .../generic/nn/recurrent/lstmLayer.cpp        |  444 ++++-
 .../generic/nn/recurrent/lstmLayerCell.cpp    |  339 ++++
 .../ops/declarable/headers/recurrent.h        |   12 +
 .../ops/declarable/helpers/impl/lstmLayer.cpp | 1460 ++++++++++++++++-
 .../ops/declarable/helpers/lstmLayer.h        |   85 +-
 libnd4j/include/ops/ops.h                     |   15 +-
 .../layers_tests/DeclarableOpsTests13.cpp     |  563 ++++++-
 .../layers_tests/PlaygroundTests.cpp          |   73 +
 14 files changed, 2896 insertions(+), 167 deletions(-)
 create mode 100644 libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp

diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX
index 1caae85a4..7756fb7ae 100644
--- a/libnd4j/include/array/NDArray.hXX
+++ b/libnd4j/include/array/NDArray.hXX
@@ -403,7 +403,6 @@ NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::Launch

 /////////////////////////////////////////////////////////////////////////
 // u8 string constructors
-/////////////////////////////////////////////////////////////////////////
 NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) {

     if (!DataTypeUtils::isS(dtype)) {
diff --git a/libnd4j/include/array/cpu/NDArrayLambda.hpp b/libnd4j/include/array/cpu/NDArrayLambda.hpp
index 8bced3de4..bd8742288 100644
--- a/libnd4j/include/array/cpu/NDArrayLambda.hpp
+++ b/libnd4j/include/array/cpu/NDArrayLambda.hpp
@@ -10,7 +10,7 @@ void NDArray::applyTriplewiseLambda(NDArray& second,
NDArray& third, const std:: throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); + nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n",""); throw std::runtime_error("Shapes mismach"); } diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h index 0d184a5a1..f5fd1f3df 100644 --- a/libnd4j/include/helpers/GradCheck.h +++ b/libnd4j/include/helpers/GradCheck.h @@ -47,13 +47,13 @@ class ND4J_EXPORT GradCheck { * opBP - back propagation operation * argsHolderFF - argument holder for feed forward operation * argsHolderBP - argument holder for back propagation operation - * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty array which means to check all arrays + * whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays * IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.} * loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM + * outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent */ static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); - + const std::vector& whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector& outArrsFFIdx = {}); }; diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 62d8153ef..73f3e54bd 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -372,16 +372,16 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con int xLenDim(0), yLenDim(0); if(!shape::isCommonVector(X->getShapeInfo(), xLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !"); + throw std::runtime_error("MmulHelper::dot: X array must be vector !"); if(!shape::isCommonVector(Y->getShapeInfo(), yLenDim)) - throw std::runtime_error("MmulHelper::dot cuda: Y array must be vector !"); + throw std::runtime_error("MmulHelper::dot: Y array must be vector !"); if(Z != nullptr && !Z->isScalar()) - throw std::runtime_error("MmulHelper::dot cuda: Z array must be scalar !"); + throw std::runtime_error("MmulHelper::dot: Z array must be scalar !"); const auto length = X->lengthOf(); if(Y->lengthOf() != length) - throw std::runtime_error("MmulHelper::dot cuda: lengths of input vectors are different !"); + throw std::runtime_error("MmulHelper::dot: lengths of input vectors are different !"); if(Z == nullptr) Z = new 
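// ---------------------------------------------------------------------------------------------
// [Editor's sketch] MmulHelper::dot above reduces two equal-length vectors to a scalar,
// z = sum_i x_i * y_i, allocating Z when none is supplied. A plain-C++ equivalent of the math
// only (dotToy is a hypothetical helper, not part of the NDArray API):
#include <cstddef>

static double dotToy(const double* x, const double* y, std::size_t n) {
    double z = 0.;
    for (std::size_t i = 0; i < n; ++i)
        z += x[i] * y[i];    // element-wise product, accumulated into the scalar result
    return z;
}
// ---------------------------------------------------------------------------------------------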
NDArray(DataTypeUtils::pickPairwiseResultType(X->dataType(), Y->dataType()), X->getContext()); diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index 2643a7b6d..f3daa798c 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& ////////////////////////////////////////////////////////////////////////// bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss ) { + const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss, const std::vector& outArrsFFIdx) { const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP @@ -82,12 +82,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons int numOutArrs = outArrsFF.size(); double scorePlus = 0.; - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); + if(!outArrsFFIdx.empty()) { + for(const auto& k : outArrsFFIdx) { // loop through independent output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } + } + else { + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); + } } // subtract epsilon, feed forward @@ -95,12 +106,23 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons outArrsFF = opFF.execute(argsHolderFF); double scoreMinus = 0.; - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); + if(!outArrsFFIdx.empty()) { + for(const auto& k : outArrsFFIdx) { // loop through independent output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } + } + else { + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); + } } // restore initial element value @@ -120,7 +142,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons throw std::runtime_error(""); } - // printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad); + // printf("%lld: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad); // calculate relative error double relError; @@ -134,7 +156,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons if(math::nd4j_abs(analyticGrad - 
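// ---------------------------------------------------------------------------------------------
// [Editor's sketch] The GradCheck changes above keep the central-difference scheme intact; for
// readers unfamiliar with it, a minimal stand-alone illustration follows. `lossOf` and
// `checkGradToy` are hypothetical names, not libnd4j API; the real code perturbs NDArray
// elements and re-executes the whole forward op.
#include <cmath>
#include <cstdio>

static double lossOf(const double* x, int n) {      // toy forward pass: L = sum of x_i^2
    double s = 0.;
    for (int i = 0; i < n; ++i) s += x[i] * x[i];
    return s;
}

// for lossOf above, the correct analyticGrad[j] is 2 * x[j]
static bool checkGradToy(double* x, const double* analyticGrad, int n,
                         double eps = 1e-5, double maxRelErr = 1e-5, double minAbsErr = 1e-6) {
    for (int j = 0; j < n; ++j) {
        const double orig = x[j];
        x[j] = orig + eps; const double scorePlus  = lossOf(x, n);   // add epsilon, feed forward
        x[j] = orig - eps; const double scoreMinus = lossOf(x, n);   // subtract epsilon, feed forward
        x[j] = orig;                                                 // restore initial element value
        const double numericalGrad = (scorePlus - scoreMinus) / (2. * eps);
        const double absDiff = std::fabs(analyticGrad[j] - numericalGrad);
        if (absDiff < minAbsErr) continue;                           // same MINABSERR shortcut as above
        const double relError = absDiff / (std::fabs(analyticGrad[j]) + std::fabs(numericalGrad));
        if (relError > maxRelErr) {
            printf("element %d: num = %.15f, ana = %.15f\n", j, numericalGrad, analyticGrad[j]);
            return false;
        }
    }
    return true;
}
// ---------------------------------------------------------------------------------------------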
numericalGrad) < MINABSERR) continue; - printf("numericalGrad = %f, analyticGrad = %f \n", numericalGrad, analyticGrad); + printf("numericalGrad = %.15f, analyticGrad = %.15f \n", numericalGrad, analyticGrad); printf("GradCheck::checkGrad: got RELERROR = %f > MAXRELERROR(%f) for input array # %i and its element at position %lld ! \n", relError, MAXRELERR, i, j); return false; } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 95f83be1a..001f8806c 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -253,7 +253,8 @@ (45, ReversePow), \ (46, DivideNoNan), \ (47, IGamma), \ - (48, IGammac) + (48, IGammac), \ + (49, RELUDerivative) diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index 3a02b8a70..8637fe990 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -24,10 +24,10 @@ #include #include + namespace sd { namespace ops { - ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { @@ -43,7 +43,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't + // ct = clip(ft ◦ ct-1 + it ◦ c't) // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) // ht = ot ◦ tanh(ct) @@ -72,26 +72,26 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // 2) [2, nOut, 4*nOut] when directionMode >= 2 // ******* - // peephole weights Wp: + // peephole weights Wp, optional: // 1) [3*nOut] when directionMode < 2 // 2) [2, 3*nOut] when directionMode >= 2 // ******* - // biases b: + // biases b, optional: // 1) [4*nOut] when directionMode < 2 // 2) [2, 4*nOut] when directionMode >= 2 // ******* - // sequence length array seqLen: - // 1) [bS] always + // sequence length array seqLen, optional: + // 1) [bS] // ******* - // initial output hI: + // initial output hI, optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // ******* - // initial cell state cI (same shape as in hI): + // initial cell state cI (same shape as in hI), optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 @@ -99,7 +99,7 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // OUTPUTS: // ******* - // output h: + // output h, optional: // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 @@ -109,19 +109,19 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 // ******* - // output at last step hL: + // output at last step hL, optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // ******* - // cell state at last step cL (same shape as in hL): + // cell state at last step cL (same shape as in hL), optional: // 1) [bS, nOut] when directionMode < 2 // 2) [2, bS, nOut] when directionMode >= 2 // !!! dimension 4*nOut implies order it, ft, c't, ot // !!! 
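// ---------------------------------------------------------------------------------------------
// [Editor's sketch] A scalar (nIn = nOut = bS = 1) walk-through of the cell equations documented
// above, with sigmoid gates and tanh cell/output activations and peepholes included. All names
// are local illustrations, not libnd4j symbols.
#include <cmath>

struct LstmCellToy {
    // per-gate weights: Wx* (input), Wr* (recurrent), Wp* (peephole), b* (bias)
    double Wxi, Wri, Wpi, bi;   // input gate i
    double Wxf, Wrf, Wpf, bf;   // forget gate f
    double Wxc, Wrc,      bc;   // cell candidate c'
    double Wxo, Wro, Wpo, bo;   // output gate o
    double cellClip = 0.;       // 0 means "do not apply clipping"

    static double sigma(double v) { return 1. / (1. + std::exp(-v)); }

    void step(double xt, double& h, double& c) const {   // on entry: h = h(t-1), c = c(t-1)
        const double it = sigma(Wxi * xt + Wri * h + Wpi * c + bi);
        const double ft = sigma(Wxf * xt + Wrf * h + Wpf * c + bf);
        const double cc = std::tanh(Wxc * xt + Wrc * h + bc);                   // c't
        double ct = ft * c + it * cc;
        if (cellClip > 0.)
            ct = std::fmax(-cellClip, std::fmin(cellClip, ct));                 // ct = clip(ct)
        const double ot = sigma(Wxo * xt + Wro * h + Wpo * ct + bo);            // peephole sees clipped ct
        c = ct;
        h = ot * std::tanh(ct);
    }
};
// ---------------------------------------------------------------------------------------------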
// !!! dimension 3*nOut implies order it, ft, ot

-    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
+    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
     const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

     // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
@@ -135,8 +135,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
     const auto hasInitC   = B_ARG(3);    // indicates whether initial cell state is provided
     const auto hasPH      = B_ARG(4);    // indicates whether peephole connections are present
     const auto retFullSeq = B_ARG(5);    // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
-    const auto retLastH   = B_ARG(6);    // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
-    const auto retLastC   = B_ARG(7);    // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+    const auto retLastH   = B_ARG(6);    // indicates whether to return output at last time step only
+    const auto retLastC   = B_ARG(7);    // indicates whether to return cell state at last time step only

     const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
     const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
     const auto outActHasAlpha  = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
@@ -176,8 +176,8 @@ CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {

     // evaluate dimensions
     const Nd4jLong sL  = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
-    const Nd4jLong bS  = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-    const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong bS  = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+    const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

     // inputs validations
@@ -323,9 +323,9 @@ DECLARE_SHAPE_FN(lstmLayer) {
     const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

     // evaluate dimensions
-    const Nd4jLong sL  = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) );
-    const Nd4jLong bS  = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
-    const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong sL  = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
+    const Nd4jLong bS  = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+    const Nd4jLong nIn = dataFormat == 2 ?
x->sizeAt(1) : x->sizeAt(2); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; DataType type; @@ -398,6 +398,412 @@ DECLARE_SHAPE_FN(lstmLayer) { } +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array seqLen, optional: + // 1) [bS] + + // ******* + // initial output hI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. output dLdh, optional: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // gradient vs output at last time step dLdhL, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // OUTPUTS: + + // ******* + // gradient vs. input dLdx: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // gradient vs. input weights dLdWx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. recurrent weights dLdWr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // gradient vs. 
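// ---------------------------------------------------------------------------------------------
// [Editor's note] Worked example of the dimension evaluation used above, spelled out per
// dataFormat value (x is the input array; for dataFormat = 3 the input keeps the [sL, bS, nIn]
// layout while only the output h gains the extra direction dimension):
//   dataFormat 0: x = [sL, bS, nIn] -> sL = x->sizeAt(0), bS = x->sizeAt(1), nIn = x->sizeAt(2)
//   dataFormat 1: x = [bS, sL, nIn] -> sL = x->sizeAt(1), bS = x->sizeAt(0), nIn = x->sizeAt(2)
//   dataFormat 2: x = [bS, nIn, sL] -> sL = x->sizeAt(2), bS = x->sizeAt(0), nIn = x->sizeAt(1)
//   dataFormat 3: x = [sL, bS, nIn] -> sL = x->sizeAt(0), bS = x->sizeAt(1), nIn = x->sizeAt(2)
// and in every case nOut = Wx->sizeAt(-1) / 4, since the four gates are packed along Wx's last axis.
// ---------------------------------------------------------------------------------------------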
peephole weights dLdWp, optional: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // gradient vs. biases dLdb, optional: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // gradient vs. sequence length array dLdsL, optional (do not calculate it!!!): + // 1) [bS] always + + // ******* + // gradient vs. initial output dLdhI, optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1} + const auto retLastH = B_ARG(6); // indicates whether gradient vs. output at last time step (dLdhL) is given + const auto retLastC = B_ARG(7); // indicates whether gradient vs. cell state at last time step (dLdcL) is given + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL !"); + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = retFullSeq ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output + const auto dLdhL = retLastH ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. output at last time step + const auto dLdcL = retLastC ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. cell state at last time step + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. input + auto dLdWx = OUTPUT_NULLIFIED(1); // gradient vs. input weights + auto dLdWr = OUTPUT_NULLIFIED(2); // gradient vs. recurrent weights + auto dLdb = hasBiases ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. biases + auto dLdsL = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // gradient vs. seqLen vector, we don't calculate it !!! + auto dLdhI = hasInitH ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial output + auto dLdcI = hasInitC ? OUTPUT_NULLIFIED(count++) : nullptr; // gradient vs. initial cell state + auto dLdWp = hasPH ? OUTPUT_NULLIFIED(count) : nullptr; // gradient vs. peephole weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const Nd4jLong nIn = dataFormat == 2 ? 
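// ---------------------------------------------------------------------------------------------
// [Editor's sketch] The count++ unpacking above packs a variable number of optional arrays into
// one flat input list; the boolean B_ARGs decide which slots exist. This toy mirrors that
// indexing with plain strings instead of the op's input variables (names illustrative only).
#include <cstdio>
#include <vector>

int main() {
    // suppose only seqLen, hI and dLdhL were passed after the three mandatory inputs x, Wx, Wr
    std::vector<const char*> inputs = {"x", "Wx", "Wr", "seqLen", "hI", "dLdhL"};
    const bool hasBiases = false, hasSeqLen = true, hasInitH = true, hasInitC = false,
               hasPH = false, retFullSeq = false, retLastH = true;

    int count = 3;                                         // x, Wx, Wr always occupy slots 0..2
    const char* b      = hasBiases  ? inputs[count++] : nullptr;
    const char* seqLen = hasSeqLen  ? inputs[count++] : nullptr;
    const char* hI     = hasInitH   ? inputs[count++] : nullptr;
    const char* cI     = hasInitC   ? inputs[count++] : nullptr;
    const char* Wp     = hasPH      ? inputs[count++] : nullptr;
    const char* dLdh   = retFullSeq ? inputs[count++] : nullptr;
    const char* dLdhL  = retLastH   ? inputs[count++] : nullptr;
    (void)b; (void)cI; (void)Wp; (void)dLdh;               // absent in this configuration

    printf("seqLen -> slot 3 (%s), hI -> slot 4 (%s), dLdhL -> slot 5 (%s)\n", seqLen, hI, dLdhL);
    return 0;
}
// ---------------------------------------------------------------------------------------------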
x->sizeAt(1) : x->sizeAt(2); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if(directionMode < 2) { // no bidirectional + + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + // gradient vs. output at last time step validation + if(dLdhL != nullptr && (dLdhL->rankOf() != 2 || dLdhL->sizeAt(0) != bS || dLdhL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str()); + // gradient vs. cell state at last time step validation + if(dLdcL != nullptr && (dLdcL->rankOf() != 2 || dLdcL->sizeAt(0) != bS || dLdcL->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. 
cell state at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
+    }
+    else {    // bidirectional
+        // Wx validation
+        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+        // Wr validation
+        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+        // biases validation
+        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+        // initial output validation
+        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+        // initial cell validation
+        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+        // peephole weights validation
+        if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
+        // gradient vs. output at last time step validation
+        if(dLdhL != nullptr && (dLdhL->rankOf() != 3 || dLdhL->sizeAt(0) != 2 || dLdhL->sizeAt(1) != bS || dLdhL->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdhL).c_str());
+        // gradient vs. cell state at last time step validation
+        if(dLdcL != nullptr && (dLdcL->rankOf() != 3 || dLdcL->sizeAt(0) != 2 || dLdcL->sizeAt(1) != bS || dLdcL->sizeAt(2) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time step, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(dLdcL).c_str());
+    }
+
+    // gradient vs. output validation
+    if(dLdh) {
+        int factor = directionMode <= 2 ? 1 : 2;
+        std::vector<Nd4jLong> expdLdhShape;
+        if(dataFormat == 0)      expdLdhShape = std::vector<Nd4jLong>{sL, bS, factor*nOut};
+        else if(dataFormat == 1) expdLdhShape = std::vector<Nd4jLong>{bS, sL, factor*nOut};
+        else if(dataFormat == 2) expdLdhShape = std::vector<Nd4jLong>{bS, factor*nOut, sL};
+        else                     expdLdhShape = std::vector<Nd4jLong>{sL, 2, bS, nOut};
+        REQUIRE_TRUE(dLdh->isSameShape(expdLdhShape), 0, "LSTM_LAYER_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expdLdhShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
+    }
+
+    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
+                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
+                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
+                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
+
+    if(directionMode == 0) {    // forward
+
+        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, true, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
+    }
+    else if(directionMode == 1) {    // backward
+
+        helpers::lstmLayerTimeLoopBp(x, Wx, Wr, b, seqLen, hI, cI, Wp, dLdh, dLdhL, dLdcL, params, false, dLdx, dLdWx, dLdWr, dLdb, dLdhI, dLdcI, dLdWp);
+    }
+    else {    // bidirectional
+
+        NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
+        NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
+        NDArray dLdWxFwd = (*dLdWx)({0,1, 0,0, 0,0});
+        NDArray dLdWxBwd = (*dLdWx)({1,2, 0,0, 0,0});
+
+        NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
+        NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});
+        NDArray dLdWrFwd = (*dLdWr)({0,1, 0,0, 0,0});
+        NDArray dLdWrBwd = (*dLdWr)({1,2, 0,0, 0,0});
+
+        NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
+                *dLdhFwd(nullptr), *dLdhBwd(nullptr), *dLdhLFwd(nullptr), *dLdhLBwd(nullptr), *dLdcLFwd(nullptr), *dLdcLBwd(nullptr),
+                *dLdWpFwd(nullptr), *dLdWpBwd(nullptr), *dLdbFwd(nullptr), *dLdbBwd(nullptr),
+                *dLdhIFwd(nullptr), *dLdhIBwd(nullptr), *dLdcIFwd(nullptr), *dLdcIBwd(nullptr);
+
+        if(Wp) {
+            WpFwd = new NDArray((*Wp)({0,1, 0,0}));
+            WpBwd = new NDArray((*Wp)({1,2, 0,0}));
+            dLdWpFwd = new NDArray((*dLdWp)({0,1, 0,0}));
+            dLdWpBwd = new NDArray((*dLdWp)({1,2, 0,0}));
+        }
+        if(b) {
+            bFwd = new NDArray((*b)({0,1, 0,0}));
+            bBwd = new NDArray((*b)({1,2, 0,0}));
+            dLdbFwd = new NDArray((*dLdb)({0,1, 0,0}));
+            dLdbBwd = new NDArray((*dLdb)({1,2, 0,0}));
+        }
+        if(hI) {
+            hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0}));
+            hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0}));
+            dLdhIFwd = new NDArray((*dLdhI)({0,1, 0,0, 0,0}));
+            dLdhIBwd = new NDArray((*dLdhI)({1,2, 0,0, 0,0}));
+        }
+        if(cI) {
+            cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0}));
+            cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0}));
+            dLdcIFwd = new NDArray((*dLdcI)({0,1, 0,0, 0,0}));
+            dLdcIBwd = new NDArray((*dLdcI)({1,2, 0,0, 0,0}));
+        }
+        if(dLdhL) {
+            dLdhLFwd = new NDArray((*dLdhL)({0,1, 0,0, 0,0}));
+            dLdhLBwd = new NDArray((*dLdhL)({1,2, 0,0, 0,0}));
+        }
+        if(dLdcL) {
+            dLdcLFwd = new NDArray((*dLdcL)({0,1, 0,0, 0,0}));
+            dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
+        }
+
+        // FIXME looks like sum (directionMode == 2) is impossible for backprop
+        if(dLdh) {
+            if(directionMode == 2) {    // sum
+                REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: bidirectional sum mode (directionMode = 2) with dLdh present makes no sense for backpropagation !");
+                // dLdhFwd = dLdh;
+                // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext());    // automatically nullifies content
+            }
+            else if(directionMode == 3) {    // concat
+                dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
+                dLdhBwd = new NDArray(dataFormat <= 1 ?
(*dLdh)({0,0, 0,0, nOut,2*nOut}) : (*dLdh)({0,0, nOut,2*nOut, 0,0})); + } + else { // directionMode == 4 + dLdhFwd = new NDArray((*dLdh)({0,0, 0,1, 0,0, 0,0})); + dLdhBwd = new NDArray((*dLdh)({0,0, 1,2, 0,0, 0,0})); + } + } + + + + helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); + NDArray dLdxBwd = dLdx->ulike(); + helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); + + *dLdx += dLdxBwd; + + delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; + delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; + delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; + delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; + + if(dLdhFwd != dLdh) + delete dLdhFwd; + } + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayer_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +DECLARE_SHAPE_FN(lstmLayer_bp) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + + int count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + std::vector outShapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()}; + + if(b != nullptr) + outShapes.push_back(b->getShapeInfo()); + if(seqLen != nullptr) + outShapes.push_back(seqLen->getShapeInfo()); + if(hI != nullptr) + outShapes.push_back(hI->getShapeInfo()); + if(cI != nullptr) + outShapes.push_back(cI->getShapeInfo()); + if(Wp != nullptr) + outShapes.push_back(Wp->getShapeInfo()); + + return new ShapeList(outShapes); +} + } } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp new file mode 100644 index 000000000..46f32e399 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -0,0 +1,339 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#if NOT_EXCLUDED(OP_lstmLayerCell) + +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // biases b (optional): [4*nOut] + // peephole weights Wp (optional): [3*nOut] + + // OUTPUTS: + // current output h: [bS, nOut] or [nOut] + // current cell state c: [bS, nOut] or [nOut] + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr; // peephole weights + + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !"); + + auto h = OUTPUT_VARIABLE(0); + auto c = OUTPUT_VARIABLE(1); + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 ? std::vector{nOut} : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + + std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c); + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayerCell) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayerCell) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + + uint count = hasBiases ? 
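// ---------------------------------------------------------------------------------------------
// [Editor's note] For readability, the layout of the 12-element params vector built above (and
// in lstmLayer / lstmLayer_bp) written as a hypothetical enum; these names are illustrative and
// do not exist in libnd4j:
enum LstmLayerParamsSlot {
    kDataFormat = 0, kDirectionMode, kCellClip,    // slots 0..2 (first two ignored by the cell ops)
    kGateAct, kGateAlpha, kGateBeta,               // slots 3..5
    kCellAct, kCellAlpha, kCellBeta,               // slots 6..8
    kOutAct,  kOutAlpha,  kOutBeta                 // slots 9..11
};
// ---------------------------------------------------------------------------------------------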
4 : 3; + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count); // initial cell state + + return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()}); +} + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = clip(ft ◦ ct-1 + it ◦ c't) + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + // input x: [bS, nIn] or [nIn] + // input weights Wx: [nIn, 4*nOut] + // recurrent weights Wr: [nOut, 4*nOut] + // initial (previous) output hI: [bS, nOut] or [nOut] + // initial (previous) cell state cI: [bS, nOut] or [nOut] + // gradient wrt output dLdh: [bS, nOut] or [nOut] + // gradient wrt cell state dLdc: [bS, nOut] or [nOut] + // peephole weights Wp (optional): [3*nOut] + // biases b (optional): [4*nOut] + + // OUTPUTS: + // gradient wrt x dLdx: [bS, nIn] or [nIn] + // gradient wrt Wx dLdWx: [nIn, 4*nOut] + // gradient wrt Wr dLdWr: [nOut, 4*nOut] + // gradient wrt hI dLdhI: [bS, nOut] or [nOut] + // gradient wrt cI dLdcI: [bS, nOut] or [nOut] + // gradient wrt b dLdb (optional): [4*nOut] + // gradient wrt Wp dLdWp (optional): [3*nOut] + + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(1); // activation for cell state (c) + const auto outAct = INT_ARG(2); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? 
T_ARG(count++) : 0; + + count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + const auto dLdh = INPUT_VARIABLE(count); // gradient wrt output + + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !"); + + count = 3; + auto dLdx = OUTPUT_VARIABLE(0); + auto dLdWx = OUTPUT_VARIABLE(1); + auto dLdWr = OUTPUT_VARIABLE(2); + auto dLdb = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr; + auto dLdhI = OUTPUT_VARIABLE(count++); + auto dLdcI = OUTPUT_VARIABLE(count++); + auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr; + + // evaluate dimensions + const Nd4jLong bS = x->rankOf() == 1 ? 0 : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str()); + // initial output/cell validation + std::vector exphIcIShape = x->rankOf() == 1 ? 
std::vector{nOut} : std::vector{bS, nOut}; + REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str()); + if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str()); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str()); + if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str()); + + + std::vector params = {static_cast(0)/*ignore*/, static_cast(0)/*ignore*/, static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + std::vector zShape = x->rankOf() == 1 ? std::vector({4*nOut}) : std::vector({bS, 4*nOut}); + + NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext()); + NDArray a = z.ulike(); + NDArray h = cI->ulike(); + NDArray c = cI->ulike(); + + helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); + + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayerCellBp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayerCellBp) { + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasPH = B_ARG(1); // indicates whether peephole connections are present + + uint count = 3; + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto hI = INPUT_VARIABLE(count++); // initial output + const auto cI = INPUT_VARIABLE(count++); // initial cell state + const auto Wp = hasPH ? 
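// ---------------------------------------------------------------------------------------------
// [Editor's note] A reference for what lstmLayerCellBp computes, reconstructed from the forward
// equations above (single cell; clipping and the weight/pre-activation chain rules omitted for
// brevity; ◦ is element-wise multiplication):
//   ht = ot ◦ tanh(ct),  ct = ft ◦ ct-1 + it ◦ c't
//   dL/dot   = dL/dht ◦ tanh(ct)
//   dL/dct   = dL/dht ◦ ot ◦ tanh'(ct) + incoming cell-state gradient (dLdc, when given)
//   dL/dit   = dL/dct ◦ c't
//   dL/dft   = dL/dct ◦ ct-1
//   dL/dc't  = dL/dct ◦ it
//   dL/dct-1 = dL/dct ◦ ft  (plus peephole contributions through it and ft when Wp is present)
// Note the op recomputes the forward intermediates (z, a, h, c) via helpers::lstmLayerCell just
// above rather than caching them, trading extra compute for not storing activations between ops.
// ---------------------------------------------------------------------------------------------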
INPUT_VARIABLE(count) : nullptr; // peephole weights + + std::vector shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()}; + + if(b != nullptr) + shapes.push_back(b->getShapeInfo()); + + shapes.push_back(hI->getShapeInfo()); + shapes.push_back(cI->getShapeInfo()); + + if(Wp != nullptr) + shapes.push_back(Wp->getShapeInfo()); + + return new ShapeList(shapes); +} + +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index 55138bb60..dd219867f 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -149,6 +149,13 @@ namespace ops { DECLARE_CUSTOM_OP(lstmCell, 8, 2, false, 3, 2); #endif + #if NOT_EXCLUDED(OP_lstmLayerCell) + DECLARE_CUSTOM_OP(lstmLayerCell, 5, 2, false, 1, 3); + #endif + #if NOT_EXCLUDED(OP_lstmLayerCell) + DECLARE_CUSTOM_OP(lstmLayerCellBp, 7, 5, false, 1, 3); + #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -236,6 +243,11 @@ namespace ops { DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5); #endif + ////////////////////////////////////////////////////////////////////////// + #if NOT_EXCLUDED(OP_lstmLayer) + DECLARE_CUSTOM_OP(lstmLayer_bp, 4, 1, false, 1, 5); + #endif + ////////////////////////////////////////////////////////////////////////// /** diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 435a3e32d..9fce17c4b 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -27,19 +28,215 @@ #include +#include +#include #include +#include // #include // #include // #include // #include // #include // #include -// #include + namespace sd { namespace ops { namespace helpers { +////////////////////////////////////////////////////////////////////////// +static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { + + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::Tanh, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELU, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::Sigmoid, z); + break; + case 3: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::Affine, z, &args); + break; + } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); + break; + case 5: + thresholdRelu(x.getContext(), x, alpha, z); + break; + case 6: { + ExtraArguments args({ static_cast(alpha), static_cast(beta)}); + (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoid, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELU, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSign, z); + break; + case 10: + (const_cast(x)).applyTransform(transform::SoftPlus, z); + break; + default: + throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + } +} + +////////////////////////////////////////////////////////////////////////// +static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { + + switch (opId) { + case 0: + (const_cast(x)).applyTransform(transform::TanhDerivative, z); + break; + case 1: + (const_cast(x)).applyScalar(scalar::RELUDerivative, 0, z); + break; + case 2: + (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); + break; + case 3: { + z = alpha; + break; + } + case 4: + (const_cast(x)).applyScalar(scalar::LeakyRELUDerivative, alpha, z); + break; + case 5: + (const_cast(x)).applyScalar(scalar::RELUDerivative, alpha, z); + break; + case 6: { + auto func = PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + auto val = beta * x.e(i); + z.p(i, alpha * beta * (1.f - sd::math::nd4j_tanh(val) * sd::math::nd4j_tanh(val))); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + case 7: + (const_cast(x)).applyTransform(transform::HardSigmoidDerivative, z); + break; + case 8: + (const_cast(x)).applyScalar(scalar::ELUDerivative, alpha, z); + break; + case 9: + (const_cast(x)).applyTransform(transform::SoftSignDerivative, z); + break; + case 10: { + auto func = PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + auto val = sd::math::nd4j_exp(x.e(i)); + z.p(i, val / (1.f + val)); + } + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); + break; + } + default: + throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); + } +} + +////////////////////////////////////////////////////////////////////////// +// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal +static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) { + + if(clipVal == 0) + return; + + auto func = 
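// ---------------------------------------------------------------------------------------------
// [Editor's sketch] Case 10 of activationDeriv above computes the softplus derivative
// d/dx log(1 + e^x) = e^x / (1 + e^x), which is exactly sigmoid(x). A tiny stand-alone
// self-check of that identity (not libnd4j code):
#include <cassert>
#include <cmath>

int main() {
    for (double x = -4.; x <= 4.; x += 0.5) {
        const double viaExp     = std::exp(x) / (1. + std::exp(x));    // as written in case 10
        const double viaSigmoid = 1. / (1. + std::exp(-x));
        assert(std::fabs(viaExp - viaSigmoid) < 1e-12);                // identical up to rounding
    }
    return 0;
}
// ---------------------------------------------------------------------------------------------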
PRAGMA_THREADS_FOR { + for(Nd4jLong i = start; i < stop; ++i) { + const auto val = c.e(i); + if(val == -clipVal || val == clipVal) { + z0.p(i, 0.f); + z1.p(i, 0.f); + z2.p(i, 0.f); + z3.p(i, 0.f); + } + } + }; + samediff::Threads::parallel_for(func, 0, c.lengthOf()); +} + +////////////////////////////////////////////////////////////////////////// +static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { + + if(dataFormat == 0 || dataFormat == 3) + return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] + + if(dataFormat == 1) + return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] + + return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] +} + +////////////////////////////////////////////////////////////////////////// +static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) { + + if(dataFormat == 0 || dataFormat == 3) + return t * bS + b; // TNS: shape [sL, bS, nIn] + + return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] +} + +////////////////////////////////////////////////////////////////////////// +// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2 +// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1 +static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) { + + std::vector outShape = x.rankOf() == 1 ? std::vector({x.sizeAt(0), nOut, nOut}) : std::vector({x.sizeAt(1), nOut, x.sizeAt(0), nOut}); + + NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext()); + + if(x.rankOf() == 1) { + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + if(i1 == i2) + dzdy.p(i0,i1,i2, x.e(i0)); + else + dzdy.p(i0,i1,i2, 0); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); + } + else { + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) { + if(i1 == i3) + dzdy.p(i0,i1,i2,i3, x.e(i2,i0)); + else + dzdy.p(i0,i1,i2,i3, 0); + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); + } + + return dzdy; +} ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -47,25 +244,27 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const std::vector& params, NDArray* h, NDArray* c) { + // * -> means element-wise multiplication + // ^ -> means matrix multiplication /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // it = σ(Wxi * xt + Wri * ht-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + bo) - // ht = ot ◦ tanh(ct) + // it = σ(Wxi ^ xt + Wri ^ ht-1 + bi) + // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf) + // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo) + // ht = ot * tanh(ct) // equations (peephole connections are 
present) - // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) - // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) - // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) - // ct = ft ◦ ct-1 + it ◦ c't - // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) - // ht = ot ◦ tanh(ct) + // it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // ct = ft * ct-1 + it * c't + // ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo) + // ht = ot * tanh(ct) // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus @@ -91,8 +290,8 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr - // hI - previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // OUTPUTS: @@ -109,24 +308,24 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // add biases if they are given if(b != nullptr) - z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut] + z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] - auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut] - auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut] - auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut] - auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut] + auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) + auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) + auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) + auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
 
     // peephole connections for input and forget gates
     if(Wp != nullptr) {
-        zi += *cI * (*Wp)({0, nOut});         // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
-        zf += *cI * (*Wp)({nOut, 2*nOut});    // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
+        zi += *cI * (*Wp)({0, nOut});         // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
+        zf += *cI * (*Wp)({nOut, 2*nOut});    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
     }
 
     applyActivation(zi, params[3], params[4], params[5], zi);   // inplace
     applyActivation(zf, params[3], params[4], params[5], zf);   // inplace
-    applyActivation(zc, params[6], params[7], params[8], zc);   // inplace
+    applyActivation(zg, params[6], params[7], params[8], zg);   // inplace
 
-    c->assign(zf * *cI + zi * zc);          // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut]
+    c->assign(zf * *cI + zi * zg);          // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
 
     // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
     if(params[2] != 0)
@@ -134,15 +333,300 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
 
     // peephole connections for output gate
     if(Wp != nullptr)
-        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut]
+        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
 
     applyActivation(zo, params[3], params[4], params[5], zo);
 
     applyActivation(*c, params[9], params[10], params[11], *h);
-    *h *= zo;                               // [bS, nOut] ◦ [bS, nOut]
+    *h *= zo;                               // [bS, nOut] * [bS, nOut](or[nOut])
 }
 
+//////////////////////////////////////////////////////////////////////////
+// this auxiliary feed-forward pass should be run before backprop
+void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                   const std::vector<float>& params,
+                   NDArray* z, NDArray* a, NDArray* h, NDArray* c) {
+
+    // z - zi, zf, zg, zo
+    // a - i, f, g, o
+
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr));   // [bS, nIn] * [nIn, 4*nOut] + [bS, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
+                                                 //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
+    // add biases if they are given
+    if(b != nullptr)
+        *z += *b;                                    // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]
+
+    auto zi = x->rankOf() == 1 ? (*z)({0,        nOut}) : (*z)({0,0, 0,        nOut});    // input gate it,  [bS, nOut](or[nOut])
+    auto zf = x->rankOf() == 1 ? (*z)({nOut,   2*nOut}) : (*z)({0,0, nOut,   2*nOut});    // forget gate ft, [bS, nOut](or[nOut])
+    auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut});    // cell gate c't,  [bS, nOut](or[nOut])
+    auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut});    // output gate ot, [bS, nOut](or[nOut])
+
+    auto i = x->rankOf() == 1 ? (*a)({0,        nOut}) : (*a)({0,0, 0,        nOut});     // input gate it,  [bS, nOut](or[nOut])
+    auto f = x->rankOf() == 1 ? (*a)({nOut,   2*nOut}) : (*a)({0,0, nOut,   2*nOut});     // forget gate ft, [bS, nOut](or[nOut])
+    auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut});     // cell gate c't,  [bS, nOut](or[nOut])
+    auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut});     // output gate ot, [bS, nOut](or[nOut])
+
+    // peephole connections for input and forget gates
+    if(Wp != nullptr) {
+        zi += *cI * (*Wp)({0, nOut});          // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
+        zf += *cI * (*Wp)({nOut, 2*nOut});     // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
+    }
+
+    applyActivation(zi, params[3], params[4], params[5], i);
+    applyActivation(zf, params[3], params[4], params[5], f);
+    applyActivation(zg, params[6], params[7], params[8], g);
+
+    c->assign(f * *cI + i * g);           // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
+
+    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
+    if(params[2] != 0)
+        c->applyScalar(scalar::LstmClip, params[2], *c);
+
+    // peephole connections for output gate
+    if(Wp != nullptr)
+        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
+
+    applyActivation(zo, params[3], params[4], params[5], o);
+
+    applyActivation(*c, params[9], params[10], params[11], *h);
+    *h *= o;                                   // [bS, nOut] * [bS, nOut](or[nOut])
+}
+
+
+//////////////////////////////////////////////////////////////////////////
+void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                     const NDArray* dLdh, const NDArray* dLdc,
+                     const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
+                     NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
+
+    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
+    /** the objective is to provide math-readable code **/
+
+    // equations (no peephole connections)
+    // zi = x ^ Wxi + hI ^ Wri + bi
+    // zf = x ^ Wxf + hI ^ Wrf + bf
+    // zg = x ^ Wxg + hI ^ Wrg + bg
+    // zo = x ^ Wxo + hI ^ Wro + bo
+    // i = act(zi)
+    // f = act(zf)
+    // g = actC(zg)
+    // o = act(zo)
+    // c = clip(f * cI + i * g)
+    // h = o * actH(c)
+
+    // equations (peephole connections are present)
+    // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi
+    // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf
+    // zg = x ^ Wxg + hI ^ Wrg + bg
+    // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo
+    // i = act(zi)
+    // f = act(zf)
+    // g = actC(zg)
+    // o = act(zo)
+    // c = clip(f * cI + i * g)
+    // h = o * actH(c)
+
+    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+
+    // params[0] - dataFormat, ignore
+    // params[1] - directionMode, ignore
+    // params[2] - cell clipping value, if it = 0 then do not apply clipping
+
+    // params[3] - activation ID for input (i), forget (f) and output (o) gates
+    // params[4] - alpha value for gates activation
+    // params[5] - beta value for gates activation
+
+    // params[6] - activation ID for cell state (c)
+    // params[7] - alpha value for cell state activation
+    // params[8] - beta value for cell state activation
+
+    // params[9] - activation ID for output (h)
+    // params[10] - alpha value for output activation
+    // params[11] - beta value for output activation
+
+    // INPUTS:
+    // x  - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
+    // Wx - input weights [nIn, 4*nOut]
+    // Wr - recurrent weights [nOut, 4*nOut]
+    // b  - biases [4*nOut], optional, may be nullptr
+    // hI - (ht-1) previous (initial) output at time t-1, 
[bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr + // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] + + // OUTPUTS: + // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr + // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] + // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] + // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] + // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + + // !!! dimension 4*nOut implies order i, f, g, o + // !!! dimension 3*nOut implies order i, f, o + + // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] + // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] + + // dLdhI += dLdh; [bS, nOut] + // dLdcI += dLdhI * dhdc; [bS, nOut] + + // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) + // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) + // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) + // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) + + // dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn] + // dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut] + // dLdcI = dLdcI*dcdcI, [bS, nOut] + + // dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxg = xT^dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] + + // dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] + + // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + const Nd4jLong nIn = x->sizeAt(-1); + + NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) + NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) + NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) + NDArray zo = x->rankOf() == 1 ? 
(*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut});      // output gate o, [bS, nOut](or[nOut])
+
+    NDArray i = x->rankOf() == 1 ? (*a)({0,        nOut}) : (*a)({0,0, 0,        nOut});     // input gate i,  [bS, nOut](or[nOut])
+    NDArray f = x->rankOf() == 1 ? (*a)({nOut,   2*nOut}) : (*a)({0,0, nOut,   2*nOut});     // forget gate f, [bS, nOut](or[nOut])
+    NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut});     // cell gate g,   [bS, nOut](or[nOut])
+    NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut});     // output gate o, [bS, nOut](or[nOut])
+
+    NDArray dLdz  = z->ulike();                                                               // [bS, 4*nOut](or[4*nOut])
+    NDArray dLdzi = x->rankOf() == 1 ? dLdz({0,        nOut}) : dLdz({0,0, 0,        nOut});
+    NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut,   2*nOut}) : dLdz({0,0, nOut,   2*nOut});
+    NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut});
+    NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut});
+
+    // dcdzi = dcdi*didzi, [bS, nOut](or[nOut])
+    activationDeriv(zi, params[3], params[4], params[5], dLdzi);   // didzi, inplace
+    dLdzi *= g;                                                    // dcdi = g*clipDeriv
+
+    // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut])
+    activationDeriv(zf, params[3], params[4], params[5], dLdzf);   // dfdzf, inplace
+    dLdzf *= *cI;                                                  // dcdf = cI*clipDeriv
+
+    // dcdzg = dcdg*dgdzg, [bS, nOut](or[nOut])
+    activationDeriv(zg, params[6], params[7], params[8], dLdzg);   // dgdzg, inplace
+    dLdzg *= i;                                                    // dcdg = i*clipDeriv
+
+    // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut])
+    activationDeriv(zo, params[3], params[4], params[5], dLdzo);
+    NDArray temp = dLdzo.ulike();
+    applyActivation(*c, params[9], params[10], params[11], temp);  // actH(c), inplace
+    dLdzo *= temp;
+
+    // dcdcI
+    NDArray dcdcI = f.dup();                                       // dcdcI = f*clipDeriv [bS, nOut](or[nOut])
+
+    // take into account possible deposit from clipping derivative
+    clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI);
+
+    // dhdc
+    NDArray dhdc = c->ulike();
+    activationDeriv(*c, params[9], params[10], params[11], dhdc);  // [bS, nOut]
+    dhdc *= o;
+
+    if(Wp) {
+        dhdc  += dLdzo*(*Wp)({2*nOut, 3*nOut});
+        dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut});    // broadcast [bS, nOut] * nOut + ... 
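+
+        // worked shape example (sizes are illustrative): with nOut = 2 and bS = 3, (*Wp)({0, nOut}) is the [2] peephole slice Wpi,
+        // so dLdzi*(*Wp)({0, nOut}) broadcasts [3, 2] * [2] -> [3, 2], i.e. each batch row of dLdzi is scaled element-wise by Wpi;
+        // the first statement adds the zo <- c peephole link to dhdc, the second adds the zi <- cI and zf <- cI links to dcdcI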
+ } + + if(dLdh) + *dLdhI += *dLdh; + if(dLdc) + *dLdcI += *dLdc; + else + *dLdcI += *dLdhI * dhdc; + + dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) + dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) + + // dLdhI + NDArray WrT = Wr->transpose(); + MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) + + // dLdcI + dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut]) + + if(x->rankOf() == 1) { + + NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1] + NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1] + NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut] + + // dLdWx + *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] + } + else { + + // dLdWx + *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] + + // dLdWr + *dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] + } + + // dLdb + if(b && x->rankOf() == 1) + *dLdb += dLdz; // [4*nOut] + else if(b) + *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; + + // dLdWp + if(Wp && x->rankOf() == 1) { + (*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut] + (*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut] + (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut] + } + else if(Wp) { + NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); + (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({0,nOut}) += temp; + (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({nOut,2*nOut}) += temp; + (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] + (*dLdWp)({2*nOut,3*nOut}) += temp; + } +} ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -172,7 +656,7 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const int dataFormat = params[0]; const int directionMode = params[1]; - const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? 
x->sizeAt(0) : x->sizeAt(1); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; @@ -192,7 +676,7 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, auto ct = cL; if(!cL) - cL = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); auto ht = hL; if(!h && !hL) @@ -300,7 +784,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -380,7 +865,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -439,7 +925,8 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr - tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + if(limit != sL) + tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } @@ -451,10 +938,915 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, delete c0Set; delete htSet; delete ctSet; + + if(!hI) + delete h0; + if(!cI) + delete c0; + if(!cL) + delete ct; + if(!h && !hL) + delete ht; +} + + +////////////////////////////////////////////////////////////////////////// +void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, + const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, + const std::vector& params, const bool forward, + NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) { + + // INPUTS: + // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // seqLen - [bS], optional, may be nullptr + // hI - initial output [bS, nOut], optional, may be nullptr + // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr + // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr + // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr + + // OUTPUTS: + // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] + // dLdWx - gradient vs. input weights [nIn, 4*nOut] + // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] + // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr + // dLdhI - gradient vs. 
initial output [bS, nOut], optional, may be nullptr + // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr + // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr + + // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] + + const int dataFormat = params[0]; + const int directionMode = params[1]; + + const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); + const int nOut = Wx->sizeAt(-1) / 4; + + auto dLdh0 = dLdhI; + if(!hI) + dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + + auto dLdc0 = dLdcI; + if(!cI) + dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + + NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext()); + NDArray a = z.ulike(); + NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext()); + NDArray c = h.ulike(); + + // create sets of required (depends on seqLen presence) sub-arrays + std::vector dims; + ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), + *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr); + + if(!seqLen) { + + dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] + + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] + hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] + cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] + zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] + if(dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] + } + else { + + dims = dataFormat == 2 ? 
std::vector({1}) : std::vector({2}); // points on nIn/nOut axis + + xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] + hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] + zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] + + if(hI) + hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if(cI) + cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + + if(dLdh) + dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] + if(!dLdh && dLdhL) + dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + if(!dLdh && !dLdhL) + dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] + } + + + // loops + if(forward) { + + if(!seqLen) { // seqLen is absent + + if(hI) + h({0,1, 0,0, 0,0}).assign(hI); + else + h({0,1, 0,0, 0,0}).nullify(); + if(cI) + c({0,1, 0,0, 0,0}).assign(cI); + else + c({0,1, 0,0, 0,0}).nullify(); + + // ff + for (int t = 0; t < sL; ++t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1)); + + // bp + for (int t = sL-1; t >= 0; --t) { + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr); + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc, + zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } + else { // seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({0,1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({0,1, e,e+1, 0,0}).nullify(); + if(cI) + c({0,1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({0,1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = 0; t < limit; ++t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e)); + + // bp + for (int t = limit-1; t >= 0; --t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? 
dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + else { // backward or bidirectional + + if(!seqLen) { // backward or bidirectional, seqLen is absent + + if(hI) + h({sL,sL+1, 0,0, 0,0}).assign(hI); + else + h({sL,sL+1, 0,0, 0,0}).nullify(); + if(cI) + c({sL,sL+1, 0,0, 0,0}).assign(cI); + else + c({sL,sL+1, 0,0, 0,0}).nullify(); + + // ff + for (int t = sL-1; t >= 0; --t) + lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t)); + + // bp + for (int t = 0; t < sL; ++t) { + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr); + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc, + zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); + } + } + else if(directionMode == 1) { // backward, seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({sL,sL+1, e,e+1, 0,0}).nullify(); + if(cI) + c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({sL,sL+1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = sL - 1; t >= sL-limit; --t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + + // bp + for (int t = sL-limit; t < sL; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? 
dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + else { // bidirectional mode, seqLen is present + + for (int e = 0; e < bS; ++e) { + + const int limit = seqLen->e(e); + + if(limit == 0) { + tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range + continue; + } + + if(hI) + h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e)); + else + h({limit,limit+1, e,e+1, 0,0}).nullify(); + if(cI) + c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e)); + else + c({limit,limit+1, e,e+1, 0,0}).nullify(); + + // ff + for (int t = limit - 1; t >= 0; --t) + lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, + zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); + + // bp + for (int t = 0; t < limit; ++t) { + const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr); + const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr); + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, + dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); + } + + if(limit != sL) + tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) + } + } + } + + delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet; + delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet; + + if(!hI) + delete dLdh0; + if(!cI) + delete dLdc0; +} + + +} +} } -} -} -} +////////////////////////////////////////////////////////////////////////// +// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, +// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, +// const std::vector& params, const bool firstIter, + +// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr, +// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, +// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { + +// /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ +// /** the objective is to provide math-readable code **/ + +// // equations (no peephole connections) +// // zi = x ^ Wxi + hI ^ Wri + bi +// // zf = x ^ Wxf + hI ^ Wrf + bf +// // zg = x ^ Wxg + hI ^ Wrg + bg +// // zo = x ^ Wxo + hI ^ Wro + bo +// // i = act(zi) +// // f = act(zf) +// // g = actC(zg) +// // o = act(zo) +// // c = clip(f * cI + i * g) +// // h = o * actH(c) + +// // equations (peephole connections are present) +// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi +// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf +// // zg = x ^ Wxg + hI ^ Wrg + bg +// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo +// // i = act(zi) +// // f = act(zf) +// // g = 
actC(zg) +// // o = act(zo) +// // c = clip(f * cI + i * g) +// // h = o * actH(c) + +// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + +// // params[0] - dataFormat, ignore +// // params[1] - directionMode, ignore +// // params[2] - cell clipping value, if it = 0 then do not apply clipping + +// // params[3] - activation ID for input (i), forget (f) and output (o) gates +// // params[4] - alpha value for gates activation +// // params[5] - beta value for gates activation + +// // params[6] - activation ID for cell state (c) +// // params[7] - alpha value for cell state activation +// // params[8] - beta value for cell state activation + +// // params[9] - activation ID for output (h) +// // params[10] - alpha value for output activation +// // params[11] - beta value for output activation + +// // INPUTS: +// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr +// // Wx - input weights [nIn, 4*nOut] +// // Wr - recurrent weights [nOut, 4*nOut] +// // b - biases [4*nOut], optional, may be nullptr +// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr +// // Wp - peephole weights [3*nOut], optional, may be nullptr +// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr +// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr +// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr +// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr +// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr +// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr +// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr +// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr + +// // OUTPUTS: +// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr +// // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] +// // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] +// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] +// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr +// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] + +// // !!! dimension 4*nOut implies order i, f, g, o +// // !!! dimension 3*nOut implies order i, f, o + +// // dcdzi = dcdi*didzi +// // dcdzf = dcdf*dfdzf +// // dcdzg = dcdg*dgdzg +// // dhdzo = dhdo*dodzo + +// // dhdc = dhdc + Wp ? 
dhdzo*dzodc : 0 [bS, nOut] +// // factor = dLdh*dhdc [bS, nOut] +// // iFactor = factor*dcdzi [bS, nOut] +// // fFactor = factor*dcdzf [bS, nOut] +// // eFactor = factor*dcdzg [bS, nOut] +// // oFactor = *dLdh*dhdzo [bS, nOut] + +// // tempC = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0; +// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT +// // tempO = dhdzo^WroT + +// // dhIdcI = dhdc_from_previous_time_step + +// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn] +// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut] +// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] + +// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] +// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] + +// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] +// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] + +// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbf(dcIdbf) = (dcdzf + 
tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] +// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] + +// const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + +// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr); +// if(Wp) { +// Wpi = new NDArray((*Wp)({0, nOut})); +// Wpf = new NDArray((*Wp)({nOut, 2*nOut})); +// Wpo = new NDArray((*Wp)({2*nOut, 3*nOut})); +// dhIdWpi = new NDArray((*dhIdWp)({0, nOut})); +// dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut})); +// dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut})); +// dcIdWpi = new NDArray((*dcIdWp)({0, nOut})); +// dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut})); +// dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut})); +// } + +// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr); +// if(b) { +// dhIdbi = new NDArray((*dhIdb)({0, nOut})); +// dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut})); +// dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut})); +// dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut})); +// dcIdbi = new NDArray((*dcIdb)({0, nOut})); +// dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut})); +// dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut})); +// dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut})); +// } + +// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dhIdWro = x->rankOf() == 1 ? 
(*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr + +// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr +// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr + +// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn] +// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn] +// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn] +// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn] + +// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut] +// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut] +// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut] +// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut] + +// // ***** feed forward step ***** // + +// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] +// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] +// // add biases if they are given +// if(b) +// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut]) + +// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) +// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) +// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) +// auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) + +// // peephole connections for input and forget gates +// if(Wp) { +// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) +// } + +// NDArray i = zi.ulike(); // [bS, nOut] +// NDArray f = zf.ulike(); // [bS, nOut] +// NDArray g = zg.ulike(); // [bS, nOut] +// applyActivation(zi, params[3], params[4], params[5], i); +// applyActivation(zf, params[3], params[4], params[5], f); +// applyActivation(zg, params[6], params[7], params[8], g); + +// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) + +// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation +// if(params[2] != 0) +// c.applyScalar(scalar::LstmClip, params[2], c); + +// // peephole connections for output gate +// if(Wp) +// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) + +// NDArray o = zo.ulike(); // [bS, nOut](or[nOut]) +// applyActivation(zo, params[3], params[4], params[5], o); + +// // ***** back prop step ***** // + +// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut]) +// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut]) + +// // dodzo +// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zo, params[3], params[4], params[5], dodzo); + +// // dhdzo = dhdo*dodzo = actH(c)*dodzo +// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) +// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c) +// hI->assign(o*dhdzo); +// dhdzo *= dodzo; + +// // dcdzi = dcdi*didzi +// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi +// dcdzi *= g; // dcdi = g*clipDeriv + +// // dcdzf = dcdf*dfdzf +// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf +// dcdzf *= *cI; // dcdf = cI*clipDeriv + +// // dcdzg = dcde*dedzg +// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) +// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg +// dcdzg *= i; // dcdf = i*clipDeriv + +// // dcdcI +// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) + +// // take into account possible deposit from clipping derivative +// clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI); + +// // dzodc +// NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dzidcI +// NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dzfdcI +// NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) + +// // dhdc +// NDArray dhdc = c.ulike(); +// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut] +// dhdc *= o; +// if(Wp) +// dhdc += dhdzo* *dzodc; + +// NDArray factor = *dLdh * dhdc; + +// NDArray iFactor = factor*dcdzi; // [bS, nOut](or[nOut]) +// NDArray fFactor = factor*dcdzf; // [bS, nOut](or[nOut]) +// NDArray eFactor = factor*dcdzg; // [bS, nOut](or[nOut]) 
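+// note: a concrete instance, assuming the common choice act = sigmoid and actH = tanh (activation IDs 2 and 0 above):
+// dodzo = o*(1 - o) and actH'(c) = 1 - tanh(c)*tanh(c), so oFactor below reduces to dLdh*tanh(c)*o*(1 - o)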
+// NDArray oFactor = *dLdh *dhdzo; // [bS, nOut](or[nOut]) + +// NDArray tempC = dcdcI; +// if(Wp) +// tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI); + +// // dLdx +// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) +// // NDArray temp = c.ulike(); +// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c) +// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut]) + +// // dLdhI +// NDArray* dLdhII = dLdhI; +// if(dLdcI && !dLdhI) +// dLdhII = new NDArray(dLdcI->ulike()); +// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) + +// if(firstIter) { + +// // dLdcI +// if(dLdcI) +// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut]) + +// // dcIdWxi(dcdWxi) +// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dcIdWxf(dcdWxf) +// dcIdWxf.assign(dcdzf*dWxJacobian); +// // dcIdWxg(dcdWxg) +// dcIdWxg.assign(dcdzg*dWxJacobian); +// // dcIdWxo(dcdWxo) = 0 +// dcIdWxo.nullify(); + +// // dhIdWxi +// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dhIdWxf +// dhIdWxf.assign(dhdc*dcIdWxf); +// // dhIdWxg +// dhIdWxg.assign(dhdc*dcIdWxg); +// // dhIdWxo +// dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */); + +// // dcIdWri(dcdWri) +// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; +// // dcIdWrf(dcdWrf) +// dcIdWrf.assign(dcdzf*dWrJacobian); +// // dcIdWrg(dcdWrg) +// dcIdWrg.assign(dcdzg*dWrJacobian); +// // dcIdWro(dcdWro) = 0 +// dcIdWro.nullify(); + +// // dhIdWri +// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); +// // dhIdWrf +// dhIdWrf.assign(dhdc*dcIdWrf); +// // dhIdWrg +// dhIdWrg.assign(dhdc*dcIdWrg); +// // dhIdWro +// dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */); + +// if(Wp && x->rankOf() == 1) { +// // dcIdWpi +// dcIdWpi->assign(dcdzi*(*cI)); // [nOut] * [nOut] +// // dcIdWpf +// dcIdWpf->assign(dcdzf*(*cI)); // [nOut] * [nOut] +// // dcIdWpo +// dcIdWpo->nullify(); // [nOut] + +// // dhdWpi +// dhIdWpi->assign(dhdc*(*dcIdWpi)); // [nOut] * [nOut] +// // dhdWpf +// dhIdWpf->assign(dhdc*(*dcIdWpf)); // [nOut] * [nOut] +// // dhdWpo +// dhIdWpo->assign(dhdzo*c /* +0*/); // [nOut] * [nOut] +// } +// else if(Wp) { +// // dcIdWpi +// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdWpf +// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdWpo +// dcIdWpo->nullify(); // [nOut] + +// // dhIdWpi +// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpf +// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// // dhIdWpo +// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// } + +// if(b && x->rankOf() == 1) { +// // dcIdbi +// dcIdbi->assign(dcdzi); // [nOut] +// // dcIdbf +// dcIdbf->assign(dcdzf); // [nOut] +// // dcIdbg +// dcIdbg->assign(dcdzg); // [nOut] +// // dcIdbo +// dcIdbo->nullify(); // [nOut] 
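+// note: dcIdbo is nullified above because at the first iteration c does not depend on bo:
+// bo enters only zo, and the output gate o scales h but never feeds into c (cf. dcdbo = 0 + ... in the comments up top)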
+ +// //dhIdbi +// dhIdbi->assign(dhdc*(*dcIdbi)); // [nOut] +// //dhIdbf +// dhIdbf->assign(dhdc*(*dcIdbf)); // [nOut] +// //dhIdbg +// dhIdbg->assign(dhdc*(*dcIdbg)); // [nOut] +// //dhIdbo +// dhIdbo->assign(dhdzo); // [nOut] + +// } +// else if(b) { +// // dcIdbi +// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbf +// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbg +// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] +// // dcIdbo +// dcIdbo->nullify(); // [nOut] + +// //dhIdbi +// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbf +// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbg +// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] +// //dhIdbo +// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] + +// } +// } +// else { + +// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT); +// NDArray tempO = mmul(dhdzo, WroT); + +// // dLdcI +// if(dLdcI) +// dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI)); + +// // dcIdWxi(dcdWxi) +// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dcIdWxf(dcdWxf) +// dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf); +// // dcIdWxg(dcdWxg) +// dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg); +// // dcIdWxo(dcdWxo) +// dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo); + +// // dhIdWxi +// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dhIdWxf +// dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf); +// // dhIdWxg +// dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg); +// // dhIdWxo +// dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo); + +// // dcIdWri(dcdWri) +// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dcIdWrf(dcdWrf) +// dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf); +// // dcIdWrg(dcdWrg) +// dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg); +// // dcIdWro(dcdWro) +// dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro); + +// // dhIdWri +// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); +// // dhIdWrf +// dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf); +// // dhIdWrg +// dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg); +// // dhIdWro +// dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro); + +// if(Wp && x->rankOf() == 1) { +// // dcIdWpi +// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut] +// // dcIdWpf +// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut] +// // dcIdWpo +// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut] + +// // dhdWpi +// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut] +// // dhdWpf +// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut] +// // dhdWpo +// 
dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut]
+//         }
+//         else if(Wp) {
+//             // dcIdWpi
+//             (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dcIdWpf
+//             (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dcIdWpo
+//             (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+
+//             // dhIdWpi
+//             (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdWpf
+//             (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdWpo
+//             (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//         }
+
+//         if(b && x->rankOf() == 1) {
+//             // dcIdbi
+//             dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut]
+//             // dcIdbf
+//             dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut]
+//             // dcIdbg
+//             dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut]
+//             // dcIdbo
+//             dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut]
+
+//             // dhIdbi
+//             dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut]
+//             // dhIdbf
+//             dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut]
+//             // dhIdbg
+//             dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut]
+//             // dhIdbo
+//             dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut]
+
+//         }
+//         else if(b) {
+//             // dcIdbi
+//             (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut]
+//             // dcIdbf
+//             (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut]
+//             // dcIdbg
+//             (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut]
+//             // dcIdbo
+//             (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut]
+
+//             // dhIdbi
+//             (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdbf
+//             (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdbg
+//             (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdbo
+//             (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+
+//         }
+//     }
+
+//     const std::vector<int> dimsToExclude = x->rankOf() == 1 ? std::vector<int>({2}) : std::vector<int>({2, 3});
+
+//     // dLdWxi, dLdWxf, dLdWxg, dLdWxo
+//     (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude);
+
+//     // dLdWri, dLdWrf, dLdWrg, dLdWro
+//     (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude);
+
+//     // dLdWpi, dLdWpf, dLdWpo
+//     if(Wp) {
+//         if(x->rankOf() == 1) {
+//             (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi));        // [nOut] * [nOut]
+//             (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf));   // [nOut] * [nOut]
+//             (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut]
+//         }
+//         else {
+//             // NDArray temp1 = (*dLdWp)({0, nOut});
+//             // NDArray temp2 = (*dLdWp)({nOut, 2*nOut});
+//             // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut});
+//             // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             (*dLdWp)({0, nOut}).assign(dhIdWpi);
+//             (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf);
+//             (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo);
+//         }
+//     }
+
+//     // dLdbi, dLdbf, dLdbg, dLdbo
+//     if(b) {
+//         if(x->rankOf() == 1) {
+//             (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi));        // [nOut] * [nOut]
+//             (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf));   // [nOut] * [nOut]
+//             (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut]
+//             (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut]
+//         }
+//         else {
+//             // NDArray temp1 = (*dLdb)({0, nOut});
+//             // NDArray temp2 = (*dLdb)({nOut, 2*nOut});
+//             // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut});
+//             // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut});
+//             // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp4, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
+//             (*dLdb)({0, nOut}).assign(dhIdbi);
+//             (*dLdb)({nOut, 2*nOut}).assign(dhIdbf);
+//             (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg);
+//             (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo);
+//         }
+//     }
+
+//     // dhIdcI
+//     if(dLdcI)
+//         dhIdcI->assign(dhdc);
+
+//     cI->assign(c);
+
+//     if(dLdcI && !dLdhI)
+//         delete dLdhII;
+//     if(Wp) {
+//         delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo;
+//     }
+//     if(b) {
+//         delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo;
+//     }
+// }
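The draft above carries full Jacobians such as dhIdWx and dcIdWr through time, which is why it stayed commented out; the live lstmLayerCellBp helper applies the cell-level chain rule directly instead. For reference, a sketch of those per-cell gradients in LaTeX notation, assuming the gate order [i, f, g, o] and the sigmoid-gate / tanh-cell / tanh-output activations used by the tests below, with \circ the element-wise product:

\begin{align*}
z_i &= x W_{xi} + h_I W_{ri} + c_I \circ w_{pi} + b_i, & i &= \sigma(z_i),\\
z_f &= x W_{xf} + h_I W_{rf} + c_I \circ w_{pf} + b_f, & f &= \sigma(z_f),\\
z_g &= x W_{xg} + h_I W_{rg} + b_g,                    & g &= \tanh(z_g),\\
c   &= \mathrm{clip}\left(f \circ c_I + i \circ g\right), & &\\
z_o &= x W_{xo} + h_I W_{ro} + c \circ w_{po} + b_o,   & o &= \sigma(z_o),\\
h   &= o \circ \tanh(c).
\end{align*}

\begin{align*}
\partial L/\partial z_o &= \partial L/\partial h \circ \tanh(c) \circ o(1-o),\\
\partial L/\partial c   &= \partial L/\partial c\big|_{\text{in}} + \partial L/\partial h \circ o \circ (1-\tanh^2 c) + \partial L/\partial z_o \circ w_{po},\\
\partial L/\partial z_i &= \partial L/\partial c \circ g \circ i(1-i), \qquad
\partial L/\partial z_f  = \partial L/\partial c \circ c_I \circ f(1-f), \qquad
\partial L/\partial z_g  = \partial L/\partial c \circ i \circ (1-g^2),\\
\partial L/\partial x   &= (\partial L/\partial z)\, W_x^{T}, \qquad
\partial L/\partial h_I  = (\partial L/\partial z)\, W_r^{T}, \qquad
\partial L/\partial c_I  = \partial L/\partial c \circ f + \partial L/\partial z_i \circ w_{pi} + \partial L/\partial z_f \circ w_{pf},\\
\partial L/\partial W_x &= x^{T}\, \partial L/\partial z, \qquad
\partial L/\partial W_r  = h_I^{T}\, \partial L/\partial z, \qquad
\partial L/\partial b    = \textstyle\sum_{\text{batch}} \partial L/\partial z,
\end{align*}

where z = [z_i, z_f, z_g, z_o], the peephole gradients are the batch sums of \partial L/\partial z_{i,f} \circ c_I and \partial L/\partial z_o \circ c, and clip contributes a pass-through derivative inside the clipping interval and zero outside it.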
diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h
index dfa9268b4..3a2d173b5 100644
--- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h
+++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h
@@ -22,7 +22,6 @@
 #define LIBND4J_LSTMLAYER_H
 
 #include 
-#include 
 
 namespace sd {
 namespace ops {
@@ -34,6 +33,20 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
                                const std::vector<float>& params,
                                NDArray* h, NDArray* c);
 
+//////////////////////////////////////////////////////////////////////////
+// this auxiliary feed-forward pass should be run before backprop: besides h and c
+// it also returns the intermediate gate arrays z and a that lstmLayerCellBp consumes
+void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                               const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                               const std::vector<float>& params,
+                               NDArray* z, NDArray* a, NDArray* h, NDArray* c);
+
+//////////////////////////////////////////////////////////////////////////
+void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                                 const NDArray* dLdh, const NDArray* dLdc,
+                                 const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
+                                 NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);
+
+
 //////////////////////////////////////////////////////////////////////////
 void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                                    const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
@@ -42,71 +55,11 @@ void ND4J_EXPORT lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const ND
                                    NDArray* h, NDArray* hL, NDArray* cL);
 
 //////////////////////////////////////////////////////////////////////////
-static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
-
-    switch (opId) {
-        case 0:
-            (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z);
-            break;
-        case 1:
-            (const_cast<NDArray&>(x)).applyScalar(scalar::RELU, 0, z);
-            break;
-        case 2:
-            (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
-            break;
-        case 3: {
-            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
-            (const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
-            break;
-        }
-        case 4:
-            (const_cast<NDArray&>(x)).applyScalar(scalar::LeakyRELU, alpha, z);
-            break;
-        case 5:
-            helpers::thresholdRelu(x.getContext(), x, alpha, z);
-            break;
-        case 6: {
-            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
-            (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
-            break;
-        }
-        case 7:
-            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
-            break;
-        case 8:
-            (const_cast<NDArray&>(x)).applyScalar(scalar::ELU, alpha, z);
-            break;
-        case 9:
-            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
-            break;
-        case 10:
-            (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
-            break;
-        default:
-            throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
-    }
-}
-
-//////////////////////////////////////////////////////////////////////////
-static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
-
-    if(dataFormat == 0 || dataFormat == 3)
-        return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]
-
-    if(dataFormat == 1)
-        return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL ,nIn]
-
-    return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
-}
-
-//////////////////////////////////////////////////////////////////////////
-static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
-
-    if(dataFormat == 0 || dataFormat == 3)
-        return t * bS + b;   // TNS: shape [sL, bS, nIn]
-
-    return b * sL + t;       // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
-}
+void ND4J_EXPORT lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                                     const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
+                                     const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
+                                     const std::vector<float>& params, const bool forward,
+                                     NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp);
 
 }
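The header now splits the cell into a feed-forward pass that materializes z and a, plus a backprop pass that consumes them. A minimal call-order sketch follows; it is editorial, not part of the patch, and assumes the libnd4j NDArray API plus the 12-entry params layout {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta} that the commented draft above indexes (params[9..11] select the output activation):

// hypothetical harness; shapes follow the declarations above
const int bS = 2, nIn = 4, nOut = 3;

NDArray x ('c', {bS, nIn},      sd::DataType::DOUBLE);
NDArray Wx('c', {nIn,  4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b ('c', {4*nOut},       sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut},       sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut},     sd::DataType::DOUBLE), cI(hI.ulike()), dLdh(hI.ulike());

NDArray z('c', {bS, 4*nOut}, sd::DataType::DOUBLE);  // gate pre-activations (assumed layout)
NDArray a(z.ulike());                                // activated gates (assumed layout)
NDArray h(hI.ulike()), c(cI.ulike());

// gradient outputs, shaped like their primals
NDArray dLdx(x.ulike()), dLdWx(Wx.ulike()), dLdWr(Wr.ulike()), dLdb(b.ulike()),
        dLdhI(hI.ulike()), dLdcI(cI.ulike()), dLdWp(Wp.ulike());

// sigmoid gates (id 2), tanh cell and output (id 0); no clipping
std::vector<float> params = {0,0,0, 2,0,0, 0,0,0, 0,0,0};

helpers::lstmLayerCell  (&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &z, &a, &h, &c); // ff first
helpers::lstmLayerCellBp(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, nullptr /*dLdc*/,
                         &z, &a, &c, params,
                         &dLdx, &dLdWx, &dLdWr, &dLdhI, &dLdcI, &dLdb, &dLdWp);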
diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h
index e49165e78..2f02af11b 100644
--- a/libnd4j/include/ops/ops.h
+++ b/libnd4j/include/ops/ops.h
@@ -1441,7 +1441,7 @@ namespace simdOps {
        }
 
        op_def static Z op(X d1) {
-           return d1;
+           return static_cast<Z>(d1);
        }
    };
 
@@ -2434,6 +2434,19 @@ namespace simdOps {
        }
    };
 
+   template <typename X, typename Y, typename Z>
+   class RELUDerivative {
+   public:
+       no_op_exec_special_same
+       no_op_exec_special_same_cuda
+
+       op_def static Z op(X d1, Y d2, Z *params) {
+           auto xt = static_cast<Z>(d1);
+           auto xf = static_cast<Z>(d2);
+           return xt > xf ? static_cast<Z>(1.f) : static_cast<Z>(0.f);
+       }
+   };
+
    template
    class SXELogitsSmoother {
    public:
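The new RELUDerivative op returns the step function 1{d1 > d2}, i.e. the ReLU slope with d2 as the threshold. A minimal standalone check of that formula (plain C++, independent of libnd4j; the harness below is illustrative only):

#include <cassert>

// mirrors the op body above: slope 1 where the input exceeds the threshold d2, else 0
template <typename X, typename Y, typename Z>
static Z reluDerivative(X d1, Y d2) {
    auto xt = static_cast<Z>(d1);
    auto xf = static_cast<Z>(d2);
    return xt > xf ? static_cast<Z>(1.f) : static_cast<Z>(0.f);
}

int main() {
    assert((reluDerivative<float, float, float>( 2.f, 0.f) == 1.f)); // above threshold
    assert((reluDerivative<float, float, float>(-1.f, 0.f) == 0.f)); // below threshold
    assert((reluDerivative<float, float, float>( 0.f, 0.f) == 0.f)); // boundary counts as 0
    return 0;
}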
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
index 4b5a24bb9..cee574dec 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
@@ -77,7 +77,7 @@ TEST_F(DeclarableOpsTests13, test_empty_range_1) {
     auto z = result.at(0);
 
     ASSERT_TRUE(z->isEmpty());
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, test_empty_range_2) {
@@ -262,7 +262,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) {
     ASSERT_EQ(result.status(), Status::OK());
     //result.at(0)->printBuffer("Output");
     ASSERT_TRUE(exp1.equalsTo(result.at(0)));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {
@@ -286,7 +286,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) {
     ASSERT_EQ(result.status(), Status::OK());
     //result.at(0)->printBuffer("Output");
     ASSERT_TRUE(exp.equalsTo(result.at(0)));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
@@ -312,7 +312,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
     //exp.printBuffer("Expect");
     //result.at(0)->printShapeInfo("Shape output");
     ASSERT_TRUE(exp.equalsTo(result.at(0)));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) {
@@ -349,7 +349,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) {
     //result.at(2)->printBuffer("Symmetrized2");
     // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
     ASSERT_TRUE(exp.equalsTo(result.at(2)));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
@@ -369,7 +369,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) {
     //exp.printBuffer("EXPect symm3");
     // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
     //ASSERT_TRUE(exp.equalsTo(result.at(0)));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
@@ -398,7 +398,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
     //exp.printBuffer("EXPect symm3");
     // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
     ASSERT_TRUE(exp4.equalsTo(res));
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, CellContains_test_1) {
@@ -420,7 +420,7 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
     //exp.printBuffer("EXPect symm3");
     // ASSERT_TRUE(exp[i]->equalsTo(result.at(i)));
     //ASSERT_TRUE(exp.equalsTo(result.at(0)));
-    
+
 }
 
 ////////////////////////////////////////////////////////////////////
@@ -712,7 +712,7 @@ TEST_F(DeclarableOpsTests13, rshift_bits_2) {
 
     ASSERT_EQ(e, *z);
 
-    
+
 }
 
 TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {
@@ -1109,6 +1109,7 @@ TEST_F(DeclarableOpsTests13, mergeavg_bp_1) {
     }
 }
 
+
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests13, lstmLayer_1) {
@@ -1200,7 +1201,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
     const auto hasInitC = true;    // initial cell state is provided
     const auto hasPH = false;      // peephole connections are absent
     const auto retFullSeq = true;  // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
-    const auto retLastH = true;    // do not return output at last time step
+    const auto retLastH = true;    // return output at last time step
     const auto retLastC = true;    // return cells state at last time step
 
     const double cellClip = 0;     // do not apply clipping
@@ -1398,7 +1399,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 }
 
 ///////////////////////////////////////////////////////////////////
@@ -1640,7 +1641,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_7) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 #endif
 }
@@ -1718,7 +1719,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 #endif
 }
@@ -1805,7 +1806,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 #endif
 }
@@ -1890,7 +1891,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 #endif
 }
@@ -1970,7 +1971,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));
-    
+
 #endif
 }
@@ -2061,10 +2062,528 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
 
     ASSERT_TRUE(expCL.isSameShape(cL));
     ASSERT_TRUE(expCL.equalsTo(cL));
-    
 #endif
 }
 
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 3;
+
+    const int dataFormat = 0;      // [sL,bS,nIn]
+    const int directionMode = 0;   // forward
+    const int gateAct = 2;         // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;         // tanh activation for cell state
+    const int outAct = 0;          // tanh activation for output
+
+    const bool hasBiases = true;   // biases array is provided
+    const bool hasSeqLen = false;  // seqLen array is not provided
+    const auto hasInitH = true;    // initial output is provided
+    const auto hasInitC = true;    // initial cell state is provided
+    const auto hasPH = true;       // peephole connections are provided
+    const auto retFullSeq = false; // do not return whole h sequence, hence no per-time-step dLdh
+    const auto retLastH = true;    // return output at last time step
+    const auto retLastC = true;    // return cell state at last time step
+
+    const double cellClip = 0.5;   // apply clipping with value 0.5
+
+    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::SUM, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
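Each of these tests delegates the actual verification to GradCheck::checkGrad, which reruns opFF under small input perturbations and compares the numeric derivative of the chosen loss (SUM or MEAN over the checked outputs) against what opBP reports. A self-contained sketch of the central-difference idea on a scalar function (illustration only, not GradCheck's implementation):

#include <cassert>
#include <cmath>

// f(x) = tanh(x); its analytic gradient is 1 - tanh^2(x)
static double f(double x)     { return std::tanh(x); }
static double fGrad(double x) { return 1.0 - std::tanh(x) * std::tanh(x); }

int main() {
    const double x = 0.37, eps = 1e-5;
    const double numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps); // central difference
    assert(std::fabs(numeric - fGrad(x)) < 1e-8);                   // gradients agree
    return 0;
}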
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 3;
+
+    const int dataFormat = 0;      // [sL,bS,nIn]
+    const int directionMode = 0;   // forward
+    const int gateAct = 2;         // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;         // tanh activation for cell state
+    const int outAct = 0;          // tanh activation for output
+
+    const bool hasBiases = true;   // biases array is provided
+    const bool hasSeqLen = false;  // seqLen array is not provided
+    const auto hasInitH = true;    // initial output is provided
+    const auto hasInitC = true;    // initial cell state is provided
+    const auto hasPH = true;       // peephole connections are provided
+    const auto retFullSeq = false; // do not return whole h sequence
+    const auto retLastH = false;   // do not return output at last time step
+    const auto retLastC = true;    // return cell state at last time step
+
+    const double cellClip = 0.5;   // apply clipping with value 0.5
+
+    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
+
+    ASSERT_TRUE(isGradCorrect);
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 3;
+
+    const int dataFormat = 1;     // [bS,sL,nIn]
+    const int directionMode = 0;  // forward
+    const int gateAct = 2;        // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;        // tanh activation for cell state
+    const int outAct = 0;         // tanh activation for output
+
+    const bool hasBiases = true;  // biases array is provided
+    const bool hasSeqLen = false; // seqLen array is not provided
+    const auto hasInitH = true;   // initial output is provided
+    const auto hasInitC = true;   // initial cell state is provided
+    const auto hasPH = true;      // peephole connections are provided
+    const auto retFullSeq = true; // return whole h {h_0, h_1, ..., h_sL-1}, [bS,sL,nOut]
+    const auto retLastH = false;  // do not return output at last time step
+    const auto retLastC = true;   // return cell state at last time step
+
+    const double cellClip = 0.5;  // apply clipping with value 0.5
+
+    NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
+    NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
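The next test, lstmLayer_bp_4, is the first to enable the seqLen input (seqLen = {2,0,4}). With a per-sample sequence length, time steps at or beyond seqLen[b] never ran in the forward pass, so their gradients must be zeroed; a sample whose seqLen is 0 contributes nothing at all. A plain-array sketch of that masking, standalone and purely illustrative of the semantics:

#include <cassert>
#include <cstddef>
#include <vector>

// zero the gradients of eliminated time steps in a [sL, bS, nIn] row-major buffer
void maskGradsBySeqLen(std::vector<double>& dLdx,
                       const std::vector<int>& seqLen, // [bS]
                       std::size_t sL, std::size_t bS, std::size_t nIn) {
    for (std::size_t t = 0; t < sL; ++t)
        for (std::size_t b = 0; b < bS; ++b)
            if (static_cast<int>(t) >= seqLen[b])
                for (std::size_t i = 0; i < nIn; ++i)
                    dLdx[(t * bS + b) * nIn + i] = 0.0; // eliminated time step
}

int main() {
    const std::size_t sL = 4, bS = 3, nIn = 3;
    std::vector<double> dLdx(sL * bS * nIn, 1.0);
    maskGradsBySeqLen(dLdx, {2, 0, 4}, sL, bS, nIn); // same seqLen as the test below
    assert(dLdx[(0 * bS + 1) * nIn] == 0.0);         // sample 1 has seqLen 0: all zero
    assert(dLdx[(3 * bS + 0) * nIn] == 0.0);         // sample 0: steps 2 and 3 zeroed
    assert(dLdx[(1 * bS + 0) * nIn] == 1.0);         // sample 0, step 1 kept
    return 0;
}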
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
+
+    const int sL = 4;
+    const int bS = 3;
+    const int nIn = 3;
+    const int nOut = 2;
+
+    const int dataFormat = 2;     // [bS, nIn, sL]
+    const int directionMode = 0;  // forward
+    const int gateAct = 2;        // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;        // tanh activation for cell state
+    const int outAct = 0;         // tanh activation for output
+
+    const bool hasBiases = true;  // biases array is provided
+    const bool hasSeqLen = true;  // seqLen array is provided
+    const auto hasInitH = true;   // initial output is provided
+    const auto hasInitC = true;   // initial cell state is provided
+    const auto hasPH = true;      // peephole connections are provided
+    const auto retFullSeq = true; // return whole h {h_0, h_1, ..., h_sL-1}; dLdh is provided per time step
+    const auto retLastH = false;  // do not return output at last time step
+    const auto retLastC = false;  // do not return cell state at last time step
+
+    const double cellClip = 0.5;  // apply clipping with value 0.5
+
+    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    // the fifth flag excludes the (integer-valued) seqLen input from the gradient check
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 3;
+
+    const int dataFormat = 1;      // [bS,sL,nIn]
+    const int directionMode = 1;   // backward
+    const int gateAct = 2;         // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;         // tanh activation for cell state
+    const int outAct = 0;          // tanh activation for output
+
+    const bool hasBiases = true;   // biases array is provided
+    const bool hasSeqLen = false;  // seqLen array is not provided
+    const auto hasInitH = true;    // initial output is provided
+    const auto hasInitC = true;    // initial cell state is provided
+    const auto hasPH = true;       // peephole connections are provided
+    const auto retFullSeq = false; // do not return whole h sequence, hence no per-time-step dLdh
+    const auto retLastH = true;    // return output at last time step
+    const auto retLastC = false;   // do not return cell state at last time step
+
+    const double cellClip = 0.5;   // apply clipping with value 0.5
+
+    NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 2;
+
+    const int dataFormat = 2;     // [bS, nIn, sL]
+    const int directionMode = 1;  // backward
+    const int gateAct = 2;        // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;        // tanh activation for cell state
+    const int outAct = 0;         // tanh activation for output
+
+    const bool hasBiases = true;  // biases array is provided
+    const bool hasSeqLen = true;  // seqLen array is provided
+    const auto hasInitH = true;   // initial output is provided
+    const auto hasInitC = true;   // initial cell state is provided
+    const auto hasPH = true;      // peephole connections are provided
+    const auto retFullSeq = true; // return whole h {h_0, h_1, ..., h_sL-1}; dLdh is provided per time step
+    const auto retLastH = false;  // do not return output at last time step
+    const auto retLastC = false;  // do not return cell state at last time step
+
+    const double cellClip = 0.5;  // apply clipping with value 0.5
+
+    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 2;
+
+    const int dataFormat = 2;      // [bS, nIn, sL]
+    const int directionMode = 2;   // bidirectional sum
+    const int gateAct = 2;         // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;         // tanh activation for cell state
+    const int outAct = 0;          // tanh activation for output
+
+    const bool hasBiases = true;   // biases array is provided
+    const bool hasSeqLen = true;   // seqLen array is provided
+    const auto hasInitH = true;    // initial output is provided
+    const auto hasInitC = true;    // initial cell state is provided
+    const auto hasPH = true;       // peephole connections are provided
+    const auto retFullSeq = false; // do not return whole h sequence, hence no per-time-step dLdh
+    const auto retLastH = true;    // return output at last time step
+    const auto retLastC = false;   // do not return cell state at last time step
+
+    const double cellClip = 0.5;   // apply clipping with value 0.5
+
+    NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
+    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
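For orientation before the remaining bidirectional cases: the directionMode values exercised across lstmLayer_bp_5 through lstmLayer_bp_9, together with the gradient shapes the tests pair with them (read off the array declarations; the op's own documentation remains authoritative):

    directionMode = 0   forward                 : Wx [nIn, 4*nOut],    dLdhL [bS, nOut]
    directionMode = 1   backward                : same shapes as forward
    directionMode = 2   bidirectional sum       : Wx [2, nIn, 4*nOut], dLdhL [2, bS, nOut]
    directionMode = 3   bidirectional concat    : dLdh [bS, sL, 2*nOut] (with dataFormat = 1)
    directionMode = 4   bidirectional extra dim : dLdh [sL, 2, bS, nOut] (with dataFormat = 3)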
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 2;
+
+    const int dataFormat = 1;     // [bS,sL,nIn]
+    const int directionMode = 3;  // bidirectional concat
+    const int gateAct = 2;        // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;        // tanh activation for cell state
+    const int outAct = 0;         // tanh activation for output
+
+    const bool hasBiases = true;  // biases array is provided
+    const bool hasSeqLen = true;  // seqLen array is provided
+    const auto hasInitH = true;   // initial output is provided
+    const auto hasInitC = true;   // initial cell state is provided
+    const auto hasPH = true;      // peephole connections are provided
+    const auto retFullSeq = true; // return whole h {h_0, h_1, ..., h_sL-1}; dLdh is provided per time step
+    const auto retLastH = false;  // do not return output at last time step
+    const auto retLastC = false;  // do not return cell state at last time step
+
+    const double cellClip = 0.5;  // apply clipping with value 0.5
+
+    NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
+    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 2;
+    const int nOut = 2;
+
+    const int dataFormat = 3;     // [sL, bS, nIn]
+    const int directionMode = 4;  // bidirectional extra output dim
+    const int gateAct = 2;        // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct = 0;        // tanh activation for cell state
+    const int outAct = 0;         // tanh activation for output
+
+    const bool hasBiases = true;  // biases array is provided
+    const bool hasSeqLen = true;  // seqLen array is provided
+    const auto hasInitH = true;   // initial output is provided
+    const auto hasInitC = true;   // initial cell state is provided
+    const auto hasPH = true;      // peephole connections are provided
+    const auto retFullSeq = true; // return whole h {h_0, h_1, ..., h_sL-1}; dLdh is provided per time step
+    const auto retLastH = false;  // do not return output at last time step
+    const auto retLastC = false;  // do not return cell state at last time step
+
+    const double cellClip = 0.5;  // apply clipping with value 0.5
+
+    NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE);
+    NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-2,0.1);
+    hI.linspace(-1.5,0.1);
+    cI.linspace(0.7,-0.1);
+    Wx.linspace(1,-0.1);
+    Wr.linspace(-1,0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayer opFF;
+    sd::ops::lstmLayer_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+
+    ASSERT_TRUE(isGradCorrect);
+}
 
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests13, batchnorm_test1) {
@@ -2091,7 +2610,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test1) {
 
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
     ASSERT_TRUE(expected.equalsTo(output));
-    
+
 }
 
 ////////////////////////////////////////////////////////////////////
@@ -2233,7 +2752,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test6) {
 
     ASSERT_TRUE(expected.isSameShapeStrict(*output));
     ASSERT_TRUE(expected.equalsTo(output));
-    
+
 }
 
 ////////////////////////////////////////////////////////////////////
@@ -2345,7 +2864,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_test9) {
 
     ASSERT_TRUE(expected.isSameShape(*output));
     ASSERT_TRUE(expected.equalsTo(output));
-    
+
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -2387,7 +2906,7 @@ TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) {
 
     ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
     ASSERT_TRUE(expdLdB.equalsTo(dLdB));
-    
+
 }
 
 
@@ -2642,7 +3161,7 @@ return;
 
     ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB));
     ASSERT_TRUE(expdLdB.equalsTo(dLdB));
-    
+
 }
 
 ////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
index 5636b2e29..166ba058f 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -844,5 +844,78 @@ TEST_F(PlaygroundTests, my) {
 
     printf("time: %i \n", time);
 }
 
+/*
+///////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, lstmLayerCellBp_1) {
+
+    const int bS = 2;
+    const int nIn = 4;
+    const int nOut = 3;
+    // const int nIn = 8;
+    // const int nOut = 6;
+
+    const float cellClip = 1.1;   // clipping value
+    const Nd4jLong gateAct = 2;   // sigmoid activation for input (i), forget (f) and output (o) gates
+    const float gateAlpha = 0;    // alpha value for gate activation, not required for sigmoid
+    const float gateBeta = 0;     // beta value for gate activation, not required for sigmoid
+    const Nd4jLong cellAct = 0;   // tanh activation for cell state
+    const float cellAlpha = 0;    // alpha value for cell state activation, not required for tanh
+    const float cellBeta = 0;     // beta value for cell state activation, not required for tanh
+    const Nd4jLong outAct = 0;    // tanh activation for output
+    const float outAlpha = 0;     // alpha value for output activation, not required for tanh
+    const float outBeta = 0;      // beta value for output activation, not required for tanh
+
+    NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray dLdh('c', {bS, nOut}, sd::DataType::DOUBLE);
+    NDArray dLdc('c', {bS, nOut}, sd::DataType::DOUBLE);
+
+    // NDArray x ('c', {nIn}, sd::DataType::DOUBLE);
+    // NDArray hI('c', {nOut}, sd::DataType::DOUBLE);
+    // NDArray cI('c', {nOut}, sd::DataType::DOUBLE);
+    // NDArray dLdh('c', {nOut}, sd::DataType::DOUBLE);
+    // NDArray dLdc('c', {nOut}, sd::DataType::DOUBLE);
+
+    NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
+    NDArray b ('c', {4*nOut}, sd::DataType::DOUBLE);
+    NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
+
+    x.linspace(-4,1);
+    hI.linspace(-2.5,0.5);
+    cI.linspace(-3,0.5);
+    Wx.linspace(0,0.1);
+    Wr.linspace(3,-0.1);
+    Wp.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    // x.assign(1.);
+    // hI.assign(2.);
+    // cI.assign(3.);
+    // Wx.assign(0.5);
+    // Wr.assign(0.5);
+    // Wp.assign(0.75);
+    // b.assign(0.7);
+
+    std::vector<double> tArgs = {cellClip};
+    std::vector<Nd4jLong> iArgs = {gateAct, cellAct, outAct};
+
+    // std::vector<bool> bArgs = {false, false};
+    // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs);
+    // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs);
+
+    std::vector<bool> bArgs = {true, true};
+    const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+    const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+
+    sd::ops::lstmLayerCell opFF;
+    sd::ops::lstmLayerCellBp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true});
+}
+
+*/
+