diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h
index f5fd1f3df..9ca18a82b 100644
--- a/libnd4j/include/helpers/GradCheck.h
+++ b/libnd4j/include/helpers/GradCheck.h
@@ -50,10 +50,9 @@ class ND4J_EXPORT GradCheck {
         *  whatArrsToCheck - specifies which output gradient arrays to check; for example {0, 1, 0} means that only the second output gradient array will be checked, default value is an empty std::vector which means to check all arrays
         *  IdxRange - specifies the range of indexes over which array elements will be checked; for example {0.2, 0.7} means the range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
         *  loss - type of scalar loss function; it specifies what element values will be filled into the input gradient arrays automatically, default value is SUM
-        *  outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
        */
        static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
-                             const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
+                             const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM);
};

diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp
index f3daa798c..12ecab75f 100644
--- a/libnd4j/include/helpers/impl/GradCheck.cpp
+++ b/libnd4j/include/helpers/impl/GradCheck.cpp
@@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
 //////////////////////////////////////////////////////////////////////////
 bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
-                          const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {
+                          const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss) {

     const int numInArrsFF     = argsHolderFF.getNumInArrs();                // at the same time numInArrsFF = number of output arrays in opBP
     const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF;  // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
@@ -82,23 +82,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
     int numOutArrs = outArrsFF.size();
     double scorePlus = 0.;

-    if(!outArrsFFIdx.empty()) {
-        for(const auto& k : outArrsFFIdx) {        // loop through independent output arrays
-            if(loss == SUM)
-                outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-            else
-                outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-            scorePlus += tmpScalar.e<double>(0);
-        }
-    }
-    else {
-        for(int k = 0; k < numOutArrs; ++k) {      // loop through output arrays
-            if(loss == SUM)
-                outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-            else
-                outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-            scorePlus += tmpScalar.e<double>(0);
-        }
+    for(int k = 0; k < numOutArrs; ++k) {          // loop through output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scorePlus += tmpScalar.e<double>(0);
     }

     // subtract epsilon, feed forward
@@ -106,23 +95,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
     outArrsFF = opFF.execute(argsHolderFF);
     double scoreMinus = 0.;

-    if(!outArrsFFIdx.empty()) {
-        for(const auto& k : outArrsFFIdx) {        // loop through independent output arrays
-            if(loss == SUM)
-                outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-            else
-                outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-            scoreMinus += tmpScalar.e<double>(0);
-        }
-    }
-    else {
-        for(int k = 0; k < numOutArrs; ++k) {      // loop through output arrays
-            if(loss == SUM)
-                outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
-            else
-                outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
-            scoreMinus += tmpScalar.e<double>(0);
-        }
+    for(int k = 0; k < numOutArrs; ++k) {          // loop through output arrays
+        if(loss == SUM)
+            outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
+        else
+            outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
+        scoreMinus += tmpScalar.e<double>(0);
     }

     // restore initial element value
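For reference, a minimal sketch of how the trimmed-down checkGrad signature is exercised for the new gru/gru_bp pair, modeled on the libnd4j test style. The shapes, fill values, and variable names here are illustrative assumptions, not code from this patch:

// Hedged sketch: numerical gradient check of gru against gru_bp (assumed test style).
void checkGruGradients() {
    const int time = 5, bS = 2, nIn = 3, nOut = 4;

    auto x    = NDArrayFactory::create<double>('c', {time, bS, nIn});
    auto hI   = NDArrayFactory::create<double>('c', {bS, nOut});
    auto Wx   = NDArrayFactory::create<double>('c', {nIn, 3*nOut});
    auto Wh   = NDArrayFactory::create<double>('c', {nOut, 3*nOut});
    auto b    = NDArrayFactory::create<double>('c', {3*nOut});
    auto dLdh = NDArrayFactory::create<double>('c', {time, bS, nOut});

    x.linspace(0.01, 0.01);    // deterministic fill so the check is reproducible

    sd::ops::gru    opFF;
    sd::ops::gru_bp opBP;

    const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b});
    const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh});

    // outArrsFFIdx is gone: every forward output now contributes to the scalar loss
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
}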
diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp
index c6cd2e8f1..dee9a7c88 100644
--- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -15,7 +16,7 @@
  ******************************************************************************/

 //
-// created by Yurii Shyrma on 15.02.2018
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <ops/declarable/CustomOperations.h>
 #include <ops/declarable/helpers/gru.h>
@@ -30,83 +31,157 @@ namespace ops {

 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {

-    auto x  = INPUT_VARIABLE(0);    // input [time x bS x iS]
-    auto h0 = INPUT_VARIABLE(1);    // initial cell output (at time step = 0) [bS x nU]
+    auto x  = INPUT_VARIABLE(0);    // input [time, bS, nIn]
+    auto hI = INPUT_VARIABLE(1);    // initial cell output (at time step = 0) [bS, nOut]

-    auto Wx = INPUT_VARIABLE(2);    // input-to-hidden  weights, [iS x 3*nU]
-    auto Wh = INPUT_VARIABLE(3);    // hidden-to-hidden weights, [nU x 3*nU]
-    auto b  = INPUT_VARIABLE(4);    // biases, [3*nU]
+    auto Wx = INPUT_VARIABLE(2);    // input-to-hidden  weights, [nIn, 3*nOut]
+    auto Wh = INPUT_VARIABLE(3);    // hidden-to-hidden weights, [nOut, 3*nOut]
+    auto b  = INPUT_VARIABLE(4);    // biases, [3*nOut]

-    auto h = OUTPUT_VARIABLE(0);    // cell outputs [time x bS x nU], that is per each time step
+    auto h = OUTPUT_VARIABLE(0);    // cell outputs [time, bS, nOut], that is per each time step

-    const int rank = x->rankOf();   // = 3
-    const int time = x->sizeAt(0);
-    const int bS   = x->sizeAt(1);
-    const int iS   = x->sizeAt(2);
-    const int nU   = h0->sizeAt(1);
+    const int bS   = x->sizeAt(1);
+    const int nIn  = x->sizeAt(2);
+    const int nOut = hI->sizeAt(1);

-    const std::vector<Nd4jLong> h0CorrectShape = {bS, nU};
-    const std::vector<Nd4jLong> wxCorrectShape = {iS, 3*nU};
-    const std::vector<Nd4jLong> whCorrectShape = {nU, 3*nU};
-    const std::vector<Nd4jLong> bCorrectShape  = {3*nU};
+    const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
+    const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
+    const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
+    const std::vector<Nd4jLong> bCorrectShape  = {3*nOut};
"GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h); + helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); return Status::OK(); } +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(gru) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} - DECLARE_TYPES(gru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - +////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru) { - const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0 - const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits] - const auto WhShapeInfo = inputShape->at(3); // hidden-to-hidden weights, [numUnits x 3*numUnits] - const auto bShapeInfo = inputShape->at(4); // biases, [3*numUnits] - const int rank = shape::rank(xShapeInfo); // = 3 - const auto time = xShapeInfo[1]; - const auto bS = xShapeInfo[2]; - const auto inSize = xShapeInfo[3]; - const auto numUnits = h0ShapeInfo[2]; + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - const std::vector h0CorrectShape = {bS, numUnits}; - const std::vector wxCorrectShape = {inSize, 3*numUnits}; - const std::vector whCorrectShape = {numUnits, 3*numUnits}; - const std::vector bCorrectShape = {3*numUnits}; + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - 
+//////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) {
+
+    auto x  = INPUT_VARIABLE(0);    // input [time, bS, nIn]
+    auto hI = INPUT_VARIABLE(1);    // initial cell output (at time step = 0) [bS, nOut]
+
+    auto Wx = INPUT_VARIABLE(2);    // input-to-hidden  weights, [nIn, 3*nOut]
+    auto Wh = INPUT_VARIABLE(3);    // hidden-to-hidden weights, [nOut, 3*nOut]
+    auto b  = INPUT_VARIABLE(4);    // biases, [3*nOut]
+
+    auto dLdh = INPUT_VARIABLE(5);  // gradient w.r.t. ff output, [time, bS, nOut]
+
+    auto dLdx  = OUTPUT_VARIABLE(0);   // gradient w.r.t. ff input, [time, bS, nIn]
+    auto dLdhI = OUTPUT_NULLIFIED(1);  // gradient w.r.t. initial cell output, [bS, nOut]
+    auto dLdWx = OUTPUT_NULLIFIED(2);  // gradient w.r.t. input-to-hidden weights, [nIn, 3*nOut]
+    auto dLdWh = OUTPUT_NULLIFIED(3);  // gradient w.r.t. hidden-to-hidden weights, [nOut, 3*nOut]
+    auto dLdb  = OUTPUT_NULLIFIED(4);  // gradient w.r.t. biases, [3*nOut]
+
+    const int time = x->sizeAt(0);
+    const int bS   = x->sizeAt(1);
+    const int nIn  = x->sizeAt(2);
+    const int nOut = hI->sizeAt(1);
+
+    const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
+    const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
+    const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
+    const std::vector<Nd4jLong> bCorrectShape  = {3*nOut};
+    const std::vector<Nd4jLong> hCorrectShape  = {time, bS, nOut};
+
+    REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+    REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+    REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
+    REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
+    REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
+
+    helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb);
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+DECLARE_TYPES(gru_bp) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(sd::DataType::ANY)
+        ->setAllowedOutputTypes({ALL_FLOATS});
+}
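And the matching backward call, mirroring the 6-inputs/5-outputs contract declared above. This is a sketch under the same assumptions as the forward example (test-style evaluate(), illustrative names):

// Hedged sketch of calling gru_bp; dLdh is the upstream gradient.
sd::ops::gru_bp opBp;
auto grads = opBp.evaluate({&x, &hI, &Wx, &Wh, &b, &dLdh});
auto* dLdx  = grads.at(0);   // [time, bS, nIn]
auto* dLdhI = grads.at(1);   // [bS, nOut]
auto* dLdWx = grads.at(2);   // [nIn, 3*nOut]
auto* dLdWh = grads.at(3);   // [nOut, 3*nOut]
auto* dLdb  = grads.at(4);   // [3*nOut]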
+//////////////////////////////////////////////////////////////////////////
+DECLARE_SHAPE_FN(gru_bp) {
+
+    auto x  = INPUT_VARIABLE(0);    // input [time, bS, nIn]
+    auto hI = INPUT_VARIABLE(1);    // initial cell output (at time step = 0) [bS, nOut]
+
+    auto Wx = INPUT_VARIABLE(2);    // input-to-hidden  weights, [nIn, 3*nOut]
+    auto Wh = INPUT_VARIABLE(3);    // hidden-to-hidden weights, [nOut, 3*nOut]
+    auto b  = INPUT_VARIABLE(4);    // biases, [3*nOut]
+
+    auto dLdh = INPUT_VARIABLE(5);  // gradient w.r.t. ff output, [time, bS, nOut]
+
+    const int time = x->sizeAt(0);
+    const int bS   = x->sizeAt(1);
+    const int nIn  = x->sizeAt(2);
+    const int nOut = hI->sizeAt(1);
+
+    const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
+    const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
+    const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
+    const std::vector<Nd4jLong> bCorrectShape  = {3*nOut};
+    const std::vector<Nd4jLong> hCorrectShape  = {time, bS, nOut};
+
+    REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+    REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+    REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
+    REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
+    REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape), 0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
+
+    Nd4jLong* dLdxShapeInfo  = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), x->getShapeInfo());
+    Nd4jLong* dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), hI->getShapeInfo());
+    Nd4jLong* dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wx->getShapeInfo());
+    Nd4jLong* dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wh->getShapeInfo());
+    Nd4jLong* dLdbShapeInfo  = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), b->getShapeInfo());
+
+    return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
+}
+
 }
 }

diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp
index 204a1ca63..037f09736 100644
--- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp
@@ -161,7 +161,7 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
     REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str());
     REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());

-    helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
+    helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);

     return Status::OK();
 }
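Since gru_bp is what GradCheck now exercises numerically, here is a self-contained scalar sanity check of the GRU cell equations and the derivative chain used in the backprop derivations. This is plain standard C++, independent of libnd4j; all weight values are illustrative:

// Scalar GRU cell: r = sigmoid(x*wrx + h0*wrh + br), u = sigmoid(x*wux + h0*wuh + bu),
// c = tanh(x*wcx + (r*h0)*wch + bc), hNew = u*h0 + (1 - u)*c.
#include <cmath>
#include <cstdio>

static double sigmoid(double z) { return 1.0 / (1.0 + std::exp(-z)); }

static double gruScalar(double x, double h0,
                        double wrx, double wrh, double br,
                        double wux, double wuh, double bu,
                        double wcx, double wch, double bc) {
    const double r = sigmoid(x*wrx + h0*wrh + br);
    const double u = sigmoid(x*wux + h0*wuh + bu);
    const double c = std::tanh(x*wcx + (r*h0)*wch + bc);
    return u*h0 + (1.0 - u)*c;
}

int main() {
    const double wrx = .1, wrh = .2, br = 0., wux = .3, wuh = .4, bu = 0., wcx = .5, wch = .6, bc = 0.;
    const double x = 0.3, h0 = -0.2, eps = 1e-6;

    // forward pass, mirroring the cell equations above
    const double r = sigmoid(x*wrx + h0*wrh + br);
    const double u = sigmoid(x*wux + h0*wuh + bu);
    const double c = std::tanh(x*wcx + (r*h0)*wch + bc);

    // analytic dhNew/dx via the same chain rule the backprop helpers derive
    const double drdx = r*(1.0 - r)*wrx;
    const double dudx = u*(1.0 - u)*wux;
    const double dcdx = (1.0 - c*c)*(wcx + h0*wch*drdx);
    const double analytic = dudx*(h0 - c) + (1.0 - u)*dcdx;

    // central finite difference, the same idea GradCheck::checkGrad applies per element
    const double numeric = (gruScalar(x + eps, h0, wrx,wrh,br, wux,wuh,bu, wcx,wch,bc)
                          - gruScalar(x - eps, h0, wrx,wrh,br, wux,wuh,bu, wcx,wch,bc)) / (2.0*eps);

    std::printf("analytic %.8f vs numeric %.8f\n", analytic, numeric);
    return 0;
}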
diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp
index 8637fe990..871291165 100644
--- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp
@@ -727,12 +727,10 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
             dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
         }

-        // FIXME looks like sum (directionMode == 2) is impossible for backprop
         if(dLdh) {
             if(directionMode == 2) {    // sum
-                REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
-                // dLdhFwd = dLdh;
-                // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext());    // automatically nullifies content
+                dLdhFwd = dLdh;
+                dLdhBwd = dLdh;
             }
             else if(directionMode == 3) {   // concat
                 dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
@@ -744,21 +742,20 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
             }
         }

+        NDArray dLdxBwd = dLdx->ulike();
-
+        // FIXME - following two calls are independent and may run in different streams
         helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true,  dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
-        NDArray dLdxBwd = dLdx->ulike();
         helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);

         *dLdx += dLdxBwd;

         delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
-        delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
+        delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
         delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;

-        if(dLdhFwd != dLdh)
-            delete dLdhFwd;
+        if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; }
     }

     return Status::OK();

diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp
index 46f32e399..4f24219bd 100644
--- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp
@@ -293,7 +293,7 @@ CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
     helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);

-    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
+    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);

     return Status::OK();
 }

diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h
index dd219867f..aeeae24c4 100644
--- a/libnd4j/include/ops/declarable/headers/recurrent.h
+++ b/libnd4j/include/ops/declarable/headers/recurrent.h
@@ -345,6 +345,10 @@ namespace ops {
     DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0);
     #endif

+    #if NOT_EXCLUDED(OP_gru)
+    DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0);
+    #endif
+
     //////////////////////////////////////////////////////////////////////////
     /**
      * Implementation of operation
"static RNN time sequences" with peep hole connections: diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp deleted file mode 100644 index b00036b81..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp +++ /dev/null @@ -1,421 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] 
× [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - 
// dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... 
× WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... - // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... 
- // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - -// ////////////////////////////////////////////////////////////////////////// -// FIXME - gruTimeLoopBP is not correct -// template -// void gruTimeLoopBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { - -// NDArray* x = inArrs[0]; // input [time, bS, iS] -// NDArray* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1 -// NDArray* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU] -// NDArray* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU] -// NDArray* b = inArrs[4]; // biases, [3*nU] -// NDArray* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next - -// NDArray* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon -// NDArray* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU] -// NDArray* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU] -// NDArray* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU] -// NDArray* dLdb = outArrs[4]; // gradient wrt b, [3*nU] - -// const Nd4jLong time = x->sizeAt(0); -// const Nd4jLong bS = x->sizeAt(1); -// const Nd4jLong iS = x->sizeAt(2); -// const Nd4jLong nU = hi->sizeAt(1); - -// NDArray h(hi->ordering(), {time, bS, nU}); // feed forward output - -// // first step, time = 0, feed forward -// NDArray x0 = (*x)({{0,1}, {}, {}}); -// NDArray hLast = h({{0,1}, {}, {}}); -// helpers::gruCell({&x0, hi, Wx, Wh, b}, &hLast); - -// // first step, time = 0, back prop -// NDArray dLdx0 = (*dLdx)({{0,1}, {}, {}}); -// NDArray dLdhLast = (*dLdh)({{0,1}, {}, {}}); -// helpers::gruCellBP({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb}); - -// // loop through the rest 
time steps -// for (Nd4jLong t = time-1; t > 0; --t) { -// for (Nd4jLong t = 1; t < time; ++t) { - -// NDArray xt = (*x)({{t,t+1}, {}, {}}); -// NDArray ht = h({{t,t+1}, {}, {}}); -// NDArray ht_1 = h({{t-1,t}, {}, {}}); -// NDArray dLdxt = (*dLdx)({{t,t+1}, {}, {}}); -// NDArray dLdht = (*dLdh)({{t,t+1}, {}, {}}); - -// NDArray dLdWxt_1 = dLdWx; -// NDArray dLdWht_1 = dLdWh; -// NDArray dLdbt_1 = dLdb; - -// // feed forward, calculation of ht -// helpers::gruCell({&xt, &ht_1, Wx, Wh, b}, &ht); - -// // back prop -// helpers::gruCellBP({&xt, &ht_1, Wx, Wh, b, &dLdht, &dLdWxt_1, &dLdWht_1, &dLdbt_1}, {&dLdxt, nullptr, dLdWx, dLdWh, dLdb}); -// } -// } - - -} -} -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu deleted file mode 100644 index bd4e878e3..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu +++ /dev/null @@ -1,365 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018 -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). 
-// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // 
hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - 
NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... 
- // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - - -} -} -} - diff --git a/libnd4j/include/ops/declarable/helpers/gru.h b/libnd4j/include/ops/declarable/helpers/gru.h index 3fecfa71b..9e98e4046 100644 --- a/libnd4j/include/ops/declarable/helpers/gru.h +++ b/libnd4j/include/ops/declarable/helpers/gru.h @@ -31,10 +31,26 @@ namespace helpers { const NDArray* bru, const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, NDArray* h); + void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b, + NDArray* gates, NDArray* h); + void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* 
h);

-    void gruCellBP(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
+    void gruCellBp(sd::LaunchContext* context,
+                   const NDArray* x, const NDArray* hLast,
+                   const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
+                   const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
+                   NDArray* dLdx, NDArray* dLdhLast,
+                   NDArray* dLdW, NDArray* dLdWc,
+                   NDArray* dLdb, NDArray* dLdbc);
+
+    void gruCellBp(sd::LaunchContext* context,
+                   const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
+                   NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
+
+    void gruTimeLoopBp(sd::LaunchContext * context,
+                       const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
+                       NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);

}
}
}
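To make the new helper surface concrete, a hedged sketch of how a caller might drive the gates-based overloads. All names, shapes, and the allocation style are assumptions read off the declarations above, not code from this patch:

// Hypothetical driver for the new helper overloads (shapes per the header comments).
auto gates = NDArrayFactory::create<float>('c', {bS, 3*nOut});  // caches [r | u | c] side by side
auto h     = NDArrayFactory::create<float>('c', {bS, nOut});

helpers::gruCell(context, &x, &hI, &Wx, &Wh, &b, &gates, &h);   // forward pass, gates saved for reuse

// backprop consumes the cached gates instead of recomputing the forward activations
helpers::gruCellBp(context, &x, &hI, &Wx, &Wh, &b, &dLdh, &gates,
                   &dLdx, &dLdhI, &dLdWx, &dLdWh, &dLdb);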
+// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio +// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" + + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc, + const NDArray* b, const NDArray* bc, + NDArray* r, NDArray* u, NDArray* c, NDArray* h) { + + //Inputs: + // x input [bS, nIn], nIn - input size + // hI previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units + // W RU weights - [nIn+nOut, 2*nOut] - reset and update gates + // Wc C weights - [nIn+nOut, nOut] - cell gate + // b r and u biases, [2*nOut] - reset and update gates + // bc c biases, [nOut] - cell gate + + //Outputs: + // r Reset gate output [bS, nOut] + // u Update gate output [bS, nOut] + // c Cell gate output [bS, nOut] + // h current cell output [bS, nOut] + + /***************************************************************************************/ + /************************ THIS IS NOT OPTIMIZED CODE ***********************************/ + /** however it is more math-friendly and convenient for backprop formulas derivation **/ + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray Wrx = (*W)({0,nIn, 0,nOut}); // [nIn, nOut] + NDArray Wux = (*W)({0,nIn, nOut,2*nOut}); // [nIn, nOut] + NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut}); // [nOut, nOut] + NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut}); // [nOut, nOut] + + NDArray Wcx = (*Wc)({0,nIn, 0,0}); // input-to-cell weights [nIn, nOut] + NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0}); // hidden-to-cell weights [nOut, nOut] + + NDArray br = (*b)({0, nOut}); // [nOut] + NDArray bu = (*b)({nOut, 2*nOut}); // [nOut] + + // × means matrix multiplication + // * means element-wise product or so called Hadamard product + + // reset gate + r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + r->applyTransform(transform::Sigmoid, *r); + + // update gate + u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + u->applyTransform(transform::Sigmoid, *u); + + // cell gate c = activation(x × Wcx + (r * hI) × Wch + bc) + c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut] + c->applyTransform(transform::Tanh, *c); + + // cell output + h->assign(*u * *hI + (1.f - *u) * *c); +} + +////////////////////////////////////////////////////////////////////////// +void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, + NDArray* gates, NDArray* h) { + + //Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that is at previous time step t-1 + // Wx weights for x - [nIn, 3*nOut] + // Wh weights for h - [nOut, 3*nOut] + // b biases [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + //Outputs: + // gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut] + // h current cell output [bS, nOut] + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) 
+ // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray temp = gates->ulike(); + MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] + temp += *b; + + MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] + + NDArray ru = (*gates)({0,0, 0,2*nOut}); // [bS, 2*nOut] + + NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] + NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + // reset and update gates + ru += temp({0,0, 0,2*nOut}); + ru.applyTransform(transform::Sigmoid, ru); + + // cell gate + c.assign(c*r + temp({0,0, 2*nOut, 3*nOut})); + c.applyTransform(transform::Tanh, c); + + // cell output + h->assign(u * *hI + (1.f - u) * c); +} + +////////////////////////////////////////////////////////////////////////// +void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { + + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + + // h cell outputs at each time step [sL, bS, nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + + // time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), t == 0 ? 
hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t)); +} + +////////////////////////////////////////////////////////////////////////// +void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hLast, + const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, + const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhLast, + NDArray* dLdW, NDArray* dLdWc, + NDArray* dLdb, NDArray* dLdbc) { + + //Inputs: + // x input [bS, iS] + // hLast previous cell output [bS, nU], that is at previous time step t-1 + // W weights - [iS+nU, 2*nU] - reset and update gates + // Wc C weights - [iS+nU, nU] - cell gate + // b r and u biases, [2*nU] - reset and update gates + // bc c biases, [nU] - cell gate + // dLdr gradient wrt reset gate, [bS, nU] + // dLdu gradient wrt update gate, [bS, nU] + // dLdc gradient wrt cell state, [bS, nU] + // dLdh gradient wrt current cell output, [bS, nU] + + //Outputs: + // dLdx gradient wrt x, [bS, iS], + // dLdhLast gradient wrt hLast, [bS, nU] + // dLdW gradient wrt W, [iS+nU, 2*nU] + // dLdWc gradient wrt Wc, [iS+nU, nU] + // dLdb gradient wrt bru [2*nU] + // dLdbc gradient wrt bc [nU] + + // * means element-wise product or so called Hadamard product + // × means matrix multiplication + + /************************************************************************************************/ + /******************************* THIS IS NOT OPTIMIZED CODE *************************************/ + /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ + + const int bS = x->sizeAt(0); + const int iS = x->sizeAt(1); + const int nU = hLast->sizeAt(1); + + NDArray xT = x->transpose(); // [iS, bS] + NDArray hLastT = hLast->transpose(); // [nU, bS] + + NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] + NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] + NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] + NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] + + NDArray Wcx = (*Wc)({0,iS, 0,0}); // input-to-cell weights [iS, nU] + NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // hidden-to-cell weights [nU, nU] + + NDArray br = (*b)({0, nU}); // [nU] + NDArray bu = (*b)({nU, 2*nU}); // [nU] + + NDArray WrxT = Wrx.transpose(); // [nU, iS] + NDArray WuxT = Wux.transpose(); // [nU, iS] + NDArray WrhT = Wrh.transpose(); // [nU, nU] + NDArray WuhT = Wuh.transpose(); // [nU, nU] + + NDArray WcxT = Wcx.transpose(); // [nU, iS] + NDArray WchT = Wch.transpose(); // [nU, nU] + + NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] + NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] + NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] + NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] + + NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] + NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] + + NDArray dLdbr = (*dLdb)({0, nU}); // [nU] + NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] + + + // ***** feed forward step ***** // + + // reset gate + NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + r.applyTransform(transform::Sigmoid, r); + + // update gate + NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + u.applyTransform(transform::Sigmoid, u); + + // cell gate c = activation(x×Wcx + (r*hLast)×Wch + bc) + NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] + 
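// note: at this point c still holds the pre-activation Zc = x × Wcx + (r * hLast) × Wch + bc; + // the tanh applied on the next line squashes it element-wise into (-1, 1), e.g. Zc = 0.5 gives c ≈ 0.46 and Zc = -2 gives c ≈ -0.96 + 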
c.applyTransform(transform::Tanh, c); + + // h = (1 - u) * c + u * hLast + + + // ***** back prop step ***** // + + // notations: + // Zr = x × Wrx + hLast × Wrh + br + // Zu = x × Wux + hLast × Wuh + bu + // Sr = sigmoid(Zr) + // Su = sigmoid(Zu) + // Zc = x × Wcx + (r * hLast) × Wch + bc + + + // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx + // = dLdx_u + dLdx_c + // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT + // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 + // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT + // dZcdr = (... * hLast) × WchT + // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT + // drdx = drdZr * dZrdx + // dZrdx = ... × WrxT + // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT + // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT + + + // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast + // = dLdhLast_h + dLdhLast_u + dLdhLast_c + // dLdhLast_h = dLdh * dhdhLast = dLdh * u + // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT + // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = + // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = + // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 + // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT + // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT + // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = + // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT + + + // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx + // dZrdWrx = xT × ... + // finally dLdWrx = xT × (dLdr * drdZr) + + + // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = + // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh + // dZrdWrh = hLastT × ... + // finally dLdWrh = hLastT × (dLdr * drdZr) + + + // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux + // dZudWux = xT × ... + // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) + + + // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh + // dZudWuh = hLastT × ... + // finally dLdWuh = hLastT × (dLdu * dudZu) + + + // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx + // dZcdWcx = xT × ... + // finally dLdWcx = xT × (dLdc * dcdZc) + + + // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch + // dZcdWch = (r*hLast)^T × ... 
+ // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) + + + // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = + // = dLdr * drdZr * dZrdbr + // dZrdbr = 1 + // finally dLdbr = dLdr * drdZr + + + // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu + // dZudbu = 1 + // finally dLdbu = dLdu * dudZu + + + // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc + // dZcdbc = 1 + // finally dLdbc = dLdc * dcdZc + + NDArray dhdc = 1.f - u; // [bS, nU] + NDArray dhdu = *hLast - c; // [bS, nU] + NDArray dudZu = u * dhdc; // [bS, nU] + NDArray drdZr = r * (1.f - r); // [bS, nU] + NDArray dcdZc = 1.f - c * c; // [bS, nU] + NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] + NDArray dLdZu = *dLdu * dudZu; // [bS, nU] + NDArray dLdZr = *dLdr * drdZr; // [bS, nU] + + // NDArray dLdc = *dLdh * dhdc; // [bS, nU] + // NDArray dLdu = *dLdh * dhdu; // [bS, nU] + // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] + + dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] + + dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] + + dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] + dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] + dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] + + dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] + dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] + + dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] +} + + +////////////////////////////////////////////////////////////////////////// +void gruCellBp(sd::LaunchContext* context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + + //Inputs: + // x input [bS, nIn] + // hI previous cell output [bS, nOut], that is at previous time step t-1 + // Wx input-to-hidden weights - [nIn, 3*nOut] + // Wh hidden-to-hidden weights - [nOut, 3*nOut] + // b biases, [3*nOut] - reset, update and cell gates + // dLdh gradient vs. ff output, [bS, nOut] + + //Outputs: + // dLdx gradient vs. x, [bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. Wx, [nIn, 3*nOut] + // dLdWh gradient vs. Wh, [nOut, 3*nOut] + // dLdb gradient vs. 
b [3*nOut] + + // 3*nOut means following sequence: reset, update, cell + + // * means element-wise product or so called Hadamard product + // × means matrix multiplication + + // formulas: + // zr = x × Wxr + hI × Whr + br + // zu = x × Wxu + hI × Whu + bu + // r = sigmoid(zr) + // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + // dLdhI += dLdh; [bS, nOut] + + + // dhdc = 1 - u [bS, nOut] + // dhdu = -c + hI [bS, nOut] + + // dcdzc = 1 - c*c; [bS, nOut] + // dudzu = u*(1-u) [bS, nOut] + // drdzr = r(1-r) [bS, nOut] + + // dzcdr = (...*hI × WhcT) [bS, nOut] + + // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut] + // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut] + // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut] + + // dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bS, nOut] × [nOut, nIn] + ... = [bS, nIn] + + // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bS, nOut] × [nOut, nOut] + ... = [bS, nOut] + + // dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxc = xT × dLdzc [nIn, bS] x [bS, nOut] = [nIn, nOut] + + // dLdWhr = hIT × dLdzr [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWhu = hIT × dLdzu [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWhc = (r*hI)T × dLdzc [nOut, bS] x [bS, nOut] = [nOut, nOut] + + // dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + // dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] + + const int nOut = hI->sizeAt(1); + + NDArray dLdz = gates->ulike(); // [bS, 3*nOut] + + NDArray dLdzru = dLdz({0,0, 0,2*nOut}); // [bS, 2*nOut] + + NDArray dLdzr = dLdz({0,0, 0,nOut}); // [bS, nOut] + NDArray dLdzu = dLdz({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] + NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose(); + + if(dLdh) + *dLdhI += *dLdh; + + NDArray temp1 = 1 - u; // [bS, nOut] + + // dLdzc + dLdzc.assign(*dLdhI * temp1 * (1-c*c)); // [bS, nOut] + + // dLdzu + dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut] + + // dLdzr + NDArray temp2 = dLdzc * (*hI) * r *(1-r); + MmulHelper::mmul(&temp2, &WhcT, &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut] + + // dLdx + NDArray WxT = Wx->transpose(); + MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn] + + // dLdWx + *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut] + + // dLdb + *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut]; + + dLdzc *= r; + + // dLdhI + NDArray WhT = Wh->transpose(); + dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut] + + // dLdWh + *dLdWh += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut] +} + + +////////////////////////////////////////////////////////////////////////// +void gruTimeLoopBp(sd::LaunchContext * context, + const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, + NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at 
time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + // dLdh gradient vs. ff output, [sL, bS, nOut] + + // dLdx gradient vs. x, [sL, bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. Wx, [nIn, 3*nOut] + // dLdWh gradient vs. Wh, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext()); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + + hSet.at(0)->assign(hI); + + // forward time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1)); + + // backward time loop + for (int t = sL-1; t >= 0; --t) + gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t), + dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 9fce17c4b..bffd13128 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -189,54 +189,6 @@ static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } -////////////////////////////////////////////////////////////////////////// -// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2 -// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1 -static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) { - - std::vector outShape = x.rankOf() == 1 ? 
std::vector({x.sizeAt(0), nOut, nOut}) : std::vector({x.sizeAt(1), nOut, x.sizeAt(0), nOut}); - - NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext()); - - if(x.rankOf() == 1) { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - if(i1 == i2) - dzdy.p(i0,i1,i2, x.e(i0)); - else - dzdy.p(i0,i1,i2, 0); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - else { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) { - if(i1 == i3) - dzdy.p(i0,i1,i2,i3, x.e(i2,i0)); - else - dzdy.p(i0,i1,i2,i3, 0); - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - - return dzdy; -} ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -245,25 +197,25 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, NDArray* h, NDArray* c) { // * -> means element-wise multiplication - // ^ -> means matrix multiplication + // × -> means matrix multiplication /************************ THIS IS NOT OPTIMIZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + bo) // ht = ot * tanh(ct) // equations (peephole connections are present) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) // ht = ot * tanh(ct) @@ -399,7 +351,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, ////////////////////////////////////////////////////////////////////////// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { @@ -407,10 +359,10 @@ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // zi = x ^ Wxi + hI ^ Wri + bi - // zf = x ^ Wxf + hI ^ Wrf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + bo + // zi = x × Wxi + hI × Wri + bi + // zf = x × Wxf + hI × Wrf + bf + // zg = x × Wxg + hI × Wrg + 
bg + // zo = x × Wxo + hI × Wro + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -419,10 +371,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // h = o * actH(c) // equations (peephole connections are present) - // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi - // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo + // zi = x × Wxi + hI × Wri + cI * Wpi + bi + // zf = x × Wxf + hI × Wrf + cI * Wpf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + c * Wpo + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -449,18 +401,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // params[11] - beta value for output activation // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr - // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] // OUTPUTS: // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr @@ -485,19 +438,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) - // dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn] - // dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut] + // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] + // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] // dLdcI = dLdcI*dcdcI, [bS, nOut] - // dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxg = xT^dLdzg [nIn, 
bS] x [bS, nOut] = [nIn, nOut] - // dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] @@ -563,10 +516,12 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con if(dLdh) *dLdhI += *dLdh; - if(dLdc) - *dLdcI += *dLdc; - else - *dLdcI += *dLdhI * dhdc; + if(dLdhL) + *dLdhI += *dLdhL; + if(dLdcL) + *dLdcI += *dLdcL; + + *dLdcI += *dLdhI * dhdc; dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) @@ -662,25 +617,27 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const std::vector shapeOut = {bS, nOut}; + const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); + auto h0 = const_cast(hI); if(!hI) { - h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); h0->nullify(); } auto c0 = const_cast(cI); if(!cI) { - c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); c0->nullify(); } auto ct = cL; if(!cL) - ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); auto ht = hL; if(!h && !hL) - ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); // create sets of required (depends on seqLen presence) sub-arrays std::vector dims; @@ -989,17 +946,19 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); const int nOut = Wx->sizeAt(-1) / 4; + const auto type = dLdh ? dLdh->dataType() : (dLdhL ? 
dLdhL->dataType() : dLdcL->dataType()); + auto dLdh0 = dLdhI; if(!hI) - dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically auto dLdc0 = dLdcI; if(!cI) - dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext()); + NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); NDArray a = z.ulike(); - NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); NDArray c = h.ulike(); // create sets of required (depends on seqLen presence) sub-arrays @@ -1041,9 +1000,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(dLdh) dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(!dLdh && dLdhL) + if(dLdhL) dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(!dLdh && !dLdhL) + if(dLdcL) dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } @@ -1054,13 +1013,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // seqLen is absent if(hI) - h({0,1, 0,0, 0,0}).assign(hI); + hSet->at(0)->assign(hI); else - h({0,1, 0,0, 0,0}).nullify(); + hSet->at(0)->nullify(); if(cI) - c({0,1, 0,0, 0,0}).assign(cI); + cSet->at(0)->assign(cI); else - c({0,1, 0,0, 0,0}).nullify(); + cSet->at(0)->nullify(); // ff for (int t = 0; t < sL; ++t) @@ -1068,9 +1027,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1086,13 +1046,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({0,1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(e)->assign(hISet->at(e)); else - h({0,1, e,e+1, 0,0}).nullify(); + hSet->at(e)->nullify(); if(cI) - c({0,1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(e)->assign(cISet->at(e)); else - c({0,1, e,e+1, 0,0}).nullify(); + cSet->at(e)->nullify(); // ff for (int t = 0; t < limit; ++t) @@ -1102,9 +1062,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? 
dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1119,13 +1080,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // backward or bidirectional, seqLen is absent if(hI) - h({sL,sL+1, 0,0, 0,0}).assign(hI); + hSet->at(sL)->assign(hI); else - h({sL,sL+1, 0,0, 0,0}).nullify(); + hSet->at(sL)->nullify(); if(cI) - c({sL,sL+1, 0,0, 0,0}).assign(cI); + cSet->at(sL)->assign(cI); else - c({sL,sL+1, 0,0, 0,0}).nullify(); + cSet->at(sL)->nullify(); // ff for (int t = sL-1; t >= 0; --t) @@ -1133,9 +1094,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1151,13 +1113,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(sL*bS + e)->assign(hISet->at(e)); else - h({sL,sL+1, e,e+1, 0,0}).nullify(); + hSet->at(sL*bS + e)->nullify(); if(cI) - c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(sL*bS + e)->assign(cISet->at(e)); else - c({sL,sL+1, e,e+1, 0,0}).nullify(); + cSet->at(sL*bS + e)->nullify(); // ff for (int t = sL - 1; t >= sL-limit; --t) @@ -1167,9 +1129,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == sL-limit && dLdcL) ? 
dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1206,9 +1169,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1248,10 +1212,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) -// // zi = x ^ Wxi + hI ^ Wri + bi -// // zf = x ^ Wxf + hI ^ Wrf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + bo +// // zi = x × Wxi + hI × Wri + bi +// // zf = x × Wxf + hI × Wrf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1260,10 +1224,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // h = o * actH(c) // // equations (peephole connections are present) -// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi -// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo +// // zi = x × Wxi + hI × Wri + cI * Wpi + bi +// // zf = x × Wxf + hI × Wrf + cI * Wpf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + c * Wpo + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1333,13 +1297,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // oFactor = *dLdh*dhdzo [bS, nOut] // // tempC = dcdcI + Wp ? 
dcdzi*dzidcI + dcdzf*dzfdcI : 0; -// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT -// // tempO = dhdzo^WroT +// // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT +// // tempO = dhdzo×WroT // // dhIdcI = dhdc_from_previous_time_step -// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn] -// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut] +// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] +// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] // // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] // // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 3a2d173b5..29c434865 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -42,7 +42,7 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra ////////////////////////////////////////////////////////////////////////// void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index cee574dec..4052e260d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -2084,11 +2084,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step + const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2097,6 +2097,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); @@ -2113,12 +2114,12 @@ std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, 
&dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::SUM, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } @@ -2131,63 +2132,6 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { const int nIn = 2; const int nOut = 3; - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); - - ASSERT_TRUE(isGradCorrect); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - const int dataFormat = 1; // [bS,sL,nIn] const int directionMode = 0; // forward const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates @@ -2199,11 +2143,11 @@ const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present - const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2233,13 +2177,13 @@ sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { const int sL = 4; const int bS = 3; @@ -2258,10 +2202,10 @@ const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2272,6 +2216,8 @@ NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2286,18 +2232,18 @@ std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { const int sL = 3; const int bS = 2; @@ -2315,11 +2261,11 @@ const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time 
step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2328,7 +2274,9 @@ NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2343,18 +2291,18 @@ std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { const int sL = 3; const int bS = 2; @@ -2373,10 +2321,10 @@ const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2387,6 +2335,8 @@ NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2401,18 +2351,18 @@ std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, 
{true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { const int sL = 3; const int bS = 2; @@ -2430,11 +2380,11 @@ const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are present - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastC = true; // cell state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2444,7 +2394,9 @@ NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2459,18 +2411,24 @@ std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { const int sL = 3; const int bS = 2; @@ -2489,10 +2447,10 @@ const 
 ///////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {

        const int sL = 3;
        const int bS = 2;
@@ -2489,10 +2447,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
        const auto hasInitC   = true;  // initial cell state is provided
        const auto hasPH      = true;  // peephole connections are provided
        const auto retFullSeq = true;  // dLdh per each time step
-       const auto retLastH   = false; // output at last time step
-       const auto retLastC   = false; // cells state at last time step
+       const auto retLastH   = true;  // output at last time step
+       const auto retLastC   = true;  // cells state at last time step

-       const double cellClip = 0.5;   // do not apply clipping
+       const double cellClip = 0.5;   // clipping

        NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
        NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
@@ -2503,6 +2461,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
        NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
        NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
        NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
+       NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+       NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);

        x.linspace(-2,0.1);
        hI.linspace(-1.5,0.1);
@@ -2517,18 +2477,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
        std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

        const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
-       const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+       const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);

        sd::ops::lstmLayer opFF;
        sd::ops::lstmLayer_bp opBP;

-       const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+       const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});

        ASSERT_TRUE(isGradCorrect);
 }

 ///////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
+TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {

        const int sL = 3;
        const int bS = 2;
@@ -2547,10 +2513,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
        const auto hasInitC   = true;  // initial cell state is provided
        const auto hasPH      = true;  // peephole connections are provided
        const auto retFullSeq = true;  // dLdh per each time step
-       const auto retLastH   = false; // output at last time step
-       const auto retLastC   = false; // cells state at last time step
+       const auto retLastH   = true;  // output at last time step
+       const auto retLastC   = true;  // cells state at last time step

-       const double cellClip = 0.5;   // do not apply clipping
+       const double cellClip = 0.5;   // clipping

        NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
        NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
@@ -2561,6 +2527,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
        NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
        NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
        NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
+       NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
+       NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);

        x.linspace(-2,0.1);
        hI.linspace(-1.5,0.1);
@@ -2575,12 +2543,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
        std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

        const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
-       const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
+       const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);

        sd::ops::lstmLayer opFF;
        sd::ops::lstmLayer_bp opBP;

-       const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
+       const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});

        ASSERT_TRUE(isGradCorrect);
 }
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
index 5fffa73c5..3d86cd92b 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp
@@ -1904,6 +1904,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
     const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
     ASSERT_TRUE(isGradCorrect);
 }
+
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
@@ -1922,3 +1923,68 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
     ASSERT_TRUE(isGradCorrect);
 }
+
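The two tests added below exercise the reworked gru op end to end. For orientation, here is a stand-alone sketch of a single GRU step for one example, written with plain std::vector so it compiles on its own; it shows where the [nIn, 3*nOut] and [nOut, 3*nOut] weight shapes used in the tests come from. The gate block order and the final interpolation convention are assumptions of this sketch, not a statement about the op's internals.

#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;   // Mat[row][col]

static double sigmoid(double v) { return 1. / (1. + std::exp(-v)); }

// One GRU time step for a single example. Wx is [nIn, 3*nOut] and Wh is
// [nOut, 3*nOut]: their columns split into reset / update / candidate blocks,
// which is why the tests size them with a factor of 3. Block order and the
// h-update convention below are illustrative assumptions.
Vec gruStep(const Vec& x, const Vec& h, const Mat& Wx, const Mat& Wh, const Vec& b) {
    const std::size_t nOut = h.size();

    // dot product of v with column `col` of W
    auto dotCol = [](const Vec& v, const Mat& W, std::size_t col) {
        double s = 0.;
        for (std::size_t i = 0; i < v.size(); ++i) s += v[i] * W[i][col];
        return s;
    };

    Vec r(nOut), u(nOut), rh(nOut), hNew(nOut);
    for (std::size_t j = 0; j < nOut; ++j) {
        r[j] = sigmoid(dotCol(x, Wx, j)        + dotCol(h, Wh, j)        + b[j]);        // reset gate
        u[j] = sigmoid(dotCol(x, Wx, nOut + j) + dotCol(h, Wh, nOut + j) + b[nOut + j]); // update gate
    }
    for (std::size_t j = 0; j < nOut; ++j)
        rh[j] = r[j] * h[j];                    // reset gate damps the recurrent contribution
    for (std::size_t j = 0; j < nOut; ++j) {
        const double c = std::tanh(dotCol(x, Wx, 2 * nOut + j) + dotCol(rh, Wh, 2 * nOut + j) + b[2 * nOut + j]); // candidate
        hNew[j] = u[j] * h[j] + (1. - u[j]) * c; // interpolate between old state and candidate
    }
    return hNew;
}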
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, gru_1) {
+
+    const int sL   = 3;
+    const int bS   = 2;
+    const int nIn  = 5;
+    const int nOut = 4;
+
+    NDArray x('c', {sL, bS, nIn}, {0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12., 12.5, 13., 13.5, 14., 14.5, 15.}, sd::DataType::FLOAT32);
+    NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32);
+    NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32);
+    NDArray b('c', {3*nOut}, sd::DataType::FLOAT32);
+
+    NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283,  0.175998, 0.837823, 1.488041, 2.13826 , 2.788478,
+                                       -0.888747, -0.491826, -0.094907,  0.302014, 0.751355, 1.182715, 1.614075, 2.045434,
+                                       -0.388876, -0.126716,  0.135444,  0.397604, 0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32);
+
+    Wx = 0.003;
+    Wh = 0.006;
+    b  = 0.5;
+
+    sd::ops::gru op;
+    auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results.status());
+
+    auto* h = results.at(0);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, gru_bp_1) {
+
+    const int sL   = 3;
+    const int bS   = 2;
+    const int nIn  = 5;
+    const int nOut = 4;
+
+    NDArray x('c', {sL, bS, nIn}, {0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12., 12.5, 13., 13.5, 14., 14.5, 15.}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {3*nOut}, sd::DataType::DOUBLE);
+
+    NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
+
+    Wx.linspace(1,-0.1);
+    Wh.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {});
+    const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});
+
+    sd::ops::gru opFF;
+    sd::ops::gru_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
+
+    ASSERT_TRUE(isGradCorrect);
+}
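gru_bp is only exercised through the numerical checker above. If one wanted to inspect the analytic gradients directly, the evaluate() pattern from gru_1 should carry over; the output ordering assumed below (one gradient per forward input, in input order) is my assumption and should be verified against the op's declaration before relying on it.

// Hypothetical direct call inside a test body like gru_bp_1, reusing its
// x, hI, Wx, Wh, b and dLdh arrays. The assumption that results arrive as
// {dLdx, dLdhI, dLdWx, dLdWh, dLdb} is not verified here.
sd::ops::gru_bp op;
auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});

ASSERT_EQ(ND4J_STATUS_OK, results.status());

auto* dLdx  = results.at(0);    // presumed gradient w.r.t. the input sequence x
auto* dLdWx = results.at(2);    // presumed gradient w.r.t. the input weights Wx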