Shyrma gru bp (#377)

* - update gru ff op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation and testing gru_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - neglect dependencies between dLdh/dLdhLast/dLdcLast in lstmLayer backprop

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2020-04-16 08:09:04 +03:00 committed by GitHub
parent 217fa433d4
commit 4247718f61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 968 additions and 1135 deletions

View File

@ -50,10 +50,9 @@ class ND4J_EXPORT GradCheck {
* whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
* IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
* loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
* outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
*/
static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM);
};

View File

@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
//////////////////////////////////////////////////////////////////////////
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss, const std::vector<int>& outArrsFFIdx) {
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss) {
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
@ -82,23 +82,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
int numOutArrs = outArrsFF.size();
double scorePlus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
// subtract epsilon, feed forward
@ -106,23 +95,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.;
if(!outArrsFFIdx.empty()) {
for(const auto& k : outArrsFFIdx) { // loop through independent output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
}
else {
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
// restore initial element value

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -15,7 +16,7 @@
******************************************************************************/
//
// created by Yurii Shyrma on 15.02.2018
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
@ -30,83 +31,157 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [time x bS x iS]
auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU]
auto x = INPUT_VARIABLE(0); // input [time, bS, nIn]
auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU]
auto b = INPUT_VARIABLE(4); // biases, [3*nU]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut]
auto b = INPUT_VARIABLE(4); // biases, [3*nOut]
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step
auto h = OUTPUT_VARIABLE(0); // cell outputs [time, bS, nOut], that is per each time step
const int rank = x->rankOf(); // = 3
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int iS = x->sizeAt(2);
const int nU = h0->sizeAt(1);
const int bS = x->sizeAt(1);
const int nIn = x->sizeAt(2);
const int nOut = hI->sizeAt(1);
const std::vector<Nd4jLong> h0CorrectShape = {bS, nU};
const std::vector<Nd4jLong> wxCorrectShape = {iS, 3*nU};
const std::vector<Nd4jLong> whCorrectShape = {nU, 3*nU};
const std::vector<Nd4jLong> bCorrectShape = {3*nU};
const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
const std::vector<Nd4jLong> bCorrectShape = {3*nOut};
REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h);
helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(gru) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_TYPES(gru) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(gru) {
const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0
const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits]
const auto WhShapeInfo = inputShape->at(3); // hidden-to-hidden weights, [numUnits x 3*numUnits]
const auto bShapeInfo = inputShape->at(4); // biases, [3*numUnits]
const int rank = shape::rank(xShapeInfo); // = 3
const auto time = xShapeInfo[1];
const auto bS = xShapeInfo[2];
const auto inSize = xShapeInfo[3];
const auto numUnits = h0ShapeInfo[2];
auto x = INPUT_VARIABLE(0); // input [time, bS, nIn]
auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut]
const std::vector<Nd4jLong> h0CorrectShape = {bS, numUnits};
const std::vector<Nd4jLong> wxCorrectShape = {inSize, 3*numUnits};
const std::vector<Nd4jLong> whCorrectShape = {numUnits, 3*numUnits};
const std::vector<Nd4jLong> bCorrectShape = {3*numUnits};
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut]
auto b = INPUT_VARIABLE(4); // biases, [3*nOut]
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int nIn = x->sizeAt(2);
const int nOut = hI->sizeAt(1);
const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
const std::vector<Nd4jLong> bCorrectShape = {3*nOut};
// evaluate output shapeInfo
Nd4jLong *hShapeInfo(nullptr);
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
hShapeInfo[0] = rank;
hShapeInfo[1] = time;
hShapeInfo[2] = bS;
hShapeInfo[3] = numUnits;
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo));
auto* hShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut});
return SHAPELIST(hShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [time, bS, nIn]
auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut]
auto b = INPUT_VARIABLE(4); // biases, [3*nOut]
auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut]
auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn]
auto dLdhI = OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut]
auto dLdWx = OUTPUT_NULLIFIED(2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut]
auto dLdWh = OUTPUT_NULLIFIED(3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut]
auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut]
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int nIn = x->sizeAt(2);
const int nOut = hI->sizeAt(1);
const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
const std::vector<Nd4jLong> bCorrectShape = {3*nOut};
const std::vector<Nd4jLong> hCorrectShape = {time, bS, nOut};
REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(gru_bp) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(gru_bp) {
auto x = INPUT_VARIABLE(0); // input [time, bS, nIn]
auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut]
auto b = INPUT_VARIABLE(4); // biases, [3*nOut]
auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut]
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int nIn = x->sizeAt(2);
const int nOut = hI->sizeAt(1);
const std::vector<Nd4jLong> h0CorrectShape = {bS, nOut};
const std::vector<Nd4jLong> wxCorrectShape = {nIn, 3*nOut};
const std::vector<Nd4jLong> whCorrectShape = {nOut, 3*nOut};
const std::vector<Nd4jLong> bCorrectShape = {3*nOut};
const std::vector<Nd4jLong> hCorrectShape = {time, bS, nOut};
REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
Nd4jLong* dLdxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), x->getShapeInfo());
Nd4jLong* dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), hI->getShapeInfo());
Nd4jLong* dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wx->getShapeInfo());
Nd4jLong* dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wh->getShapeInfo());
Nd4jLong* dLdbShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), b->getShapeInfo());
return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
}
}
}

View File

@ -161,7 +161,7 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str());
REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
return Status::OK();
}

View File

@ -727,12 +727,10 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0}));
}
// FIXME looks like sum (directionMode == 2) is impossible for backprop
if(dLdh) {
if(directionMode == 2) { // sum
REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !");
// dLdhFwd = dLdh;
// dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content
dLdhFwd = dLdh;
dLdhBwd = dLdh;
}
else if(directionMode == 3) { // concat
dLdhFwd = new NDArray(dataFormat <= 1 ? (*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0}));
@ -744,21 +742,20 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) {
}
}
NDArray dLdxBwd = dLdx->ulike();
// FIXME - following two calls are independent and may run in different streams
helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd);
NDArray dLdxBwd = dLdx->ulike();
helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd);
*dLdx += dLdxBwd;
delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd;
delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd;
delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd;
delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd;
if(dLdhFwd != dLdh)
delete dLdhFwd;
if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; }
}
return Status::OK();

View File

@ -293,7 +293,7 @@ CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);
helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
return Status::OK();
}

View File

@ -345,6 +345,10 @@ namespace ops {
DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0);
#endif
#if NOT_EXCLUDED(OP_gru)
DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0);
#endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operation "static RNN time sequences" with peep hole connections:

View File

@ -1,421 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black
//
// implementation of gated Recurrent Unit cell
// (cf. https://arxiv.org/abs/1406.1078).
// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
#include <ops/declarable/helpers/gru.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <helpers/MmulHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
const NDArray* b, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs:
// x input [bS, iS], iS - input size
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
// W RU weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
//Outputs:
// r Reset gate output [bS, nU]
// u Update gate output [bS, nU]
// c Cell gate output [bS, nU]
// h current cell output [bS, nU]
/***************************************************************************************/
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
// × means matrix multipication
// * means element-wise product or so called Hadamard product
// reset gate
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r->applyTransform(transform::Sigmoid, *r);
// update gate
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u->applyTransform(transform::Sigmoid, *u);
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c->applyTransform(transform::Tanh, *c);
NDArray temp = 1.f - *c * *c;
// cell output
h->assign(*u * *hLast + (1.f - *u) * *c);
/***************************************************************************************/
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
/***************************************************************************************/
/*
//Concat inputs: x + hLast : [bs, iS + nU]
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
// m += *bru;
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
r->assign(m({0,0, 0, nU}));
u->assign(m({0,0, nU, 2*nU}));
// hLast = hLast * r
xhConcat({0,0, iS, iS+nU}) *= *r;
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
*c += *bc;
c->applyTransform(transform::Tanh);
//Output: h = (1-u).*c + u .* hPrev
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast
auto temp = (1.0f - *u);
temp *= (*c);
(*h) += temp;
*/
}
//////////////////////////////////////////////////////////////////////////
void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
// x input [time, bS, iS]
// hLast initial cell output (at time step = 0) [bS, nU]
// Wx input-to-hidden weights, [iS, 3*nU]
// Wh hidden-to-hidden weights, [nU, 3*nU]
// b biases, [3*nU]
// h is cell outputs at each time step [time, bS, nU]
const int time = x->sizeAt(0);
NDArray ht_1(*hLast);
// loop through time steps
for (int t = 0; t < time; ++t) {
auto xt = (*x)({t,t+1, 0,0, 0,0});
auto ht = (*h)({t,t+1, 0,0, 0,0});
// helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
// ht_1.assign(ht);
}
}
//////////////////////////////////////////////////////////////////////////
void gruCellBP(sd::LaunchContext* context,
const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc) {
//Inputs:
// x input [bS, iS]
// hLast previous cell output [bS, nU], that is at previous time step t-1
// W weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
// dLdr gradient wrt reset gate, [bS, nU]
// dLdu gradient wrt update gate, [bS, nU]
// dLdc gradient wrt cell state, [bS, nU]
// dLdh gradient wrt current cell output, [bS, nU]
//Outputs:
// dLdx gradient wrt x, [bS, iS],
// dLdhLast gradient wrt hLast, [bS, nU]
// dLdW gradient wrt W, [iS+nU, 2*nU]
// dLdWc gradient wrt Wc, [iS+nU, nU]
// dLdb gradient wrt bru [2*nU]
// dLdbc gradient wrt bc [nU]
// * means element-wise product or so called Hadamard product
// × means matrix multiplication
/************************************************************************************************/
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray xT = x->transpose(); // [iS, bS]
NDArray hLastT = hLast->transpose(); // [nU, bS]
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
NDArray WrxT = Wrx.transpose(); // [nU, iS]
NDArray WuxT = Wux.transpose(); // [nU, iS]
NDArray WrhT = Wrh.transpose(); // [nU, nU]
NDArray WuhT = Wuh.transpose(); // [nU, nU]
NDArray WcxT = Wcx.transpose(); // [nU, iS]
NDArray WchT = Wch.transpose(); // [nU, nU]
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
// ***** feed forward step ***** //
// reset gate
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r.applyTransform(transform::Sigmoid, r);
// update gate
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u.applyTransform(transform::Sigmoid, u);
// cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c.applyTransform(transform::Tanh, c);
// h = (1 - u) * c + u * hPrev
// ***** back prop step ***** //
// notations:
// Zr = x × Wrx + hLast × Wrh + br
// Zu = x × Wux + hLast × Wuh + bu
// Sr = sigmoid(Zr)
// Su = sigmoid(Zu)
// Zc = x × Wcx + (r * hlast) × Wch + bc
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
// = dLdx_u + dLdx_c
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
// dZcdr = (... * hLast) × WchT
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
// drdx = drdZr * dZrdx
// dZrdx = ... × WrxT
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
// dLdhLast_h = dLdh * dhdhLas = dLdh * u
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
// dZrdWrx = xT × ...
// finally dLdWrx = xT × (dLdr * drdZr)
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
// dZrdWrh = hLastT × ...
// finally dLdWrh = hLastT × (dLdr * drdZr)
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
// dZudWux = xT × ...
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
// dZudWuh = hLastT × ...
// finally dLdWuh = hLastT × (dLdu * dudZu)
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
// dZcdWcx = xT × ...
// finally dLdWcx = xT × (dLdc * dcdZc)
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
// dZcdWch = (r*hLast)^T × ...
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
// = dLdr * drdZr * dZrdbr
// dZrdbr = 1
// finally dLdbr = dLdr * drdZr
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
// dZudbu = 1
// finally dLdbu = dLdu * dudZu
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
// dZcdbc = 1
// finally dLdbc = dLdc * dcdZc
NDArray dhdc = 1.f - u; // [bS, nU]
NDArray dhdu = *hLast - c; // [bS, nU]
NDArray dudZu = u * dhdc; // [bS, nU]
NDArray drdZr = r * (1.f - r); // [bS, nU]
NDArray dcdZc = 1.f - c * c; // [bS, nU]
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU]
}
// //////////////////////////////////////////////////////////////////////////
// FIXME - gruTimeLoopBP is not correct
// template <typename T>
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU]
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU]
// NDArray<T>* b = inArrs[4]; // biases, [3*nU]
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU]
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU]
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU]
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU]
// const Nd4jLong time = x->sizeAt(0);
// const Nd4jLong bS = x->sizeAt(1);
// const Nd4jLong iS = x->sizeAt(2);
// const Nd4jLong nU = hi->sizeAt(1);
// NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output
// // first step, time = 0, feed forward
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
// NDArray<T> hLast = h({{0,1}, {}, {}});
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &hLast);
// // first step, time = 0, back prop
// NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
// NDArray<T> dLdhLast = (*dLdh)({{0,1}, {}, {}});
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
// // loop through the rest time steps
// for (Nd4jLong t = time-1; t > 0; --t) {
// for (Nd4jLong t = 1; t < time; ++t) {
// NDArray<T> xt = (*x)({{t,t+1}, {}, {}});
// NDArray<T> ht = h({{t,t+1}, {}, {}});
// NDArray<T> ht_1 = h({{t-1,t}, {}, {}});
// NDArray<T> dLdxt = (*dLdx)({{t,t+1}, {}, {}});
// NDArray<T> dLdht = (*dLdh)({{t,t+1}, {}, {}});
// NDArray<T> dLdWxt_1 = dLdWx;
// NDArray<T> dLdWht_1 = dLdWh;
// NDArray<T> dLdbt_1 = dLdb;
// // feed forward, calculation of ht
// helpers::gruCell<T>({&xt, &ht_1, Wx, Wh, b}, &ht);
// // back prop
// helpers::gruCellBP<T>({&xt, &ht_1, Wx, Wh, b, &dLdht, &dLdWxt_1, &dLdWht_1, &dLdbt_1}, {&dLdxt, nullptr, dLdWx, dLdWh, dLdb});
// }
// }
}
}
}

View File

@ -1,365 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018
//
// implementation of gated Recurrent Unit cell
// (cf. https://arxiv.org/abs/1406.1078).
// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
#include<ops/declarable/helpers/gru.h>
#include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/transforms.h>
#include <helpers/MmulHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
const NDArray* b, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs:
// x input [bS, iS], iS - input size
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
// W RU weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
//Outputs:
// r Reset gate output [bS, nU]
// u Update gate output [bS, nU]
// c Cell gate output [bS, nU]
// h current cell output [bS, nU]
/***************************************************************************************/
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
// × means matrix multipication
// * means element-wise product or so called Hadamard product
// reset gate
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r->applyTransform(transform::Sigmoid, *r);
// update gate
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u->applyTransform(transform::Sigmoid, *u);
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c->applyTransform(transform::Tanh, *c);
NDArray temp = 1.f - *c * *c;
// cell output
h->assign(*u * *hLast + (1.f - *u) * *c);
/***************************************************************************************/
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
/***************************************************************************************/
/*
//Concat inputs: x + hLast : [bs, iS + nU]
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
// m += *bru;
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
r->assign(m({0,0, 0, nU}));
u->assign(m({0,0, nU, 2*nU}));
// hLast = hLast * r
xhConcat({0,0, iS, iS+nU}) *= *r;
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
*c += *bc;
c->applyTransform(transform::Tanh);
//Output: h = (1-u).*c + u .* hPrev
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast
auto temp = (1.0f - *u);
temp *= (*c);
(*h) += temp;
*/
}
//////////////////////////////////////////////////////////////////////////
void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
// x input [time, bS, iS]
// hLast initial cell output (at time step = 0) [bS, nU]
// Wx input-to-hidden weights, [iS, 3*nU]
// Wh hidden-to-hidden weights, [nU, 3*nU]
// b biases, [3*nU]
// h is cell outputs at each time step [time, bS, nU]
const int time = x->sizeAt(0);
NDArray ht_1(*hLast);
// loop through time steps
for (int t = 0; t < time; ++t) {
auto xt = (*x)({t,t+1, 0,0, 0,0});
auto ht = (*h)({t,t+1, 0,0, 0,0});
// helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
// ht_1.assign(ht);
}
}
//////////////////////////////////////////////////////////////////////////
void gruCellBP(sd::LaunchContext* context,
const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc) {
//Inputs:
// x input [bS, iS]
// hLast previous cell output [bS, nU], that is at previous time step t-1
// W weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
// dLdr gradient wrt reset gate, [bS, nU]
// dLdu gradient wrt update gate, [bS, nU]
// dLdc gradient wrt cell state, [bS, nU]
// dLdh gradient wrt current cell output, [bS, nU]
//Outputs:
// dLdx gradient wrt x, [bS, iS],
// dLdhLast gradient wrt hLast, [bS, nU]
// dLdW gradient wrt W, [iS+nU, 2*nU]
// dLdWc gradient wrt Wc, [iS+nU, nU]
// dLdb gradient wrt bru [2*nU]
// dLdbc gradient wrt bc [nU]
// * means element-wise product or so called Hadamard product
// × means matrix multiplication
/************************************************************************************************/
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray xT = x->transpose(); // [iS, bS]
NDArray hLastT = hLast->transpose(); // [nU, bS]
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
NDArray WrxT = Wrx.transpose(); // [nU, iS]
NDArray WuxT = Wux.transpose(); // [nU, iS]
NDArray WrhT = Wrh.transpose(); // [nU, nU]
NDArray WuhT = Wuh.transpose(); // [nU, nU]
NDArray WcxT = Wcx.transpose(); // [nU, iS]
NDArray WchT = Wch.transpose(); // [nU, nU]
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
// ***** feed forward step ***** //
// reset gate
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r.applyTransform(transform::Sigmoid, r);
// update gate
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u.applyTransform(transform::Sigmoid, u);
// cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c.applyTransform(transform::Tanh, c);
// h = (1 - u) * c + u * hPrev
// ***** back prop step ***** //
// notations:
// Zr = x × Wrx + hLast × Wrh + br
// Zu = x × Wux + hLast × Wuh + bu
// Sr = sigmoid(Zr)
// Su = sigmoid(Zu)
// Zc = x × Wcx + (r * hlast) × Wch + bc
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
// = dLdx_u + dLdx_c
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
// dZcdr = (... * hLast) × WchT
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
// drdx = drdZr * dZrdx
// dZrdx = ... × WrxT
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
// dLdhLast_h = dLdh * dhdhLas = dLdh * u
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
// dZrdWrx = xT × ...
// finally dLdWrx = xT × (dLdr * drdZr)
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
// dZrdWrh = hLastT × ...
// finally dLdWrh = hLastT × (dLdr * drdZr)
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
// dZudWux = xT × ...
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
// dZudWuh = hLastT × ...
// finally dLdWuh = hLastT × (dLdu * dudZu)
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
// dZcdWcx = xT × ...
// finally dLdWcx = xT × (dLdc * dcdZc)
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
// dZcdWch = (r*hLast)^T × ...
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
// = dLdr * drdZr * dZrdbr
// dZrdbr = 1
// finally dLdbr = dLdr * drdZr
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
// dZudbu = 1
// finally dLdbu = dLdu * dudZu
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
// dZcdbc = 1
// finally dLdbc = dLdc * dcdZc
NDArray dhdc = 1.f - u; // [bS, nU]
NDArray dhdu = *hLast - c; // [bS, nU]
NDArray dudZu = u * dhdc; // [bS, nU]
NDArray drdZr = r * (1.f - r); // [bS, nU]
NDArray dcdZc = 1.f - c * c; // [bS, nU]
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU]
}
}
}
}

View File

@ -31,10 +31,26 @@ namespace helpers {
const NDArray* bru, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h);
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b,
NDArray* gates, NDArray* h);
void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h);
void gruCellBP(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
void gruCellBp(sd::LaunchContext* context,
const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc);
void gruCellBp(sd::LaunchContext* context,
const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
void gruTimeLoopBp(sd::LaunchContext * context,
const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
}
}
}

View File

@ -0,0 +1,546 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* ThnIn program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which nIn available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* dnIntributed under the License nIn dnIntributed on an "AS nIn" BASnIn, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permnInsions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black
//
// implementation of gated Recurrent Unit cell
// (cf. https://arxiv.org/abs/1406.1078).
// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
// "Learning Phrase Representations using RNN Encoder-Decoder for StatnIntical Machine Translation"
#include <ops/declarable/helpers/gru.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <helpers/MmulHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc,
const NDArray* b, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs:
// x input [bS, nIn], nIn - input size
// hI previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units
// W RU weights - [nIn+nOut, 2*nOut] - reset and update gates
// Wc C weights - [nIn+nOut, nOut] - cell gate
// b r and u biases, [2*nOut] - reset and update gates
// bc c biases, [nOut] - cell gate
//Outputs:
// r Reset gate output [bS, nOut]
// u Update gate output [bS, nOut]
// c Cell gate output [bS, nOut]
// h current cell output [bS, nOut]
/***************************************************************************************/
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
const int bS = x->sizeAt(0);
const int nIn = x->sizeAt(1);
const int nOut = hI->sizeAt(1);
NDArray Wrx = (*W)({0,nIn, 0,nOut}); // [nIn, nOut]
NDArray Wux = (*W)({0,nIn, nOut,2*nOut}); // [nIn, nOut]
NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut}); // [nOut, nOut]
NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut}); // [nOut, nOut]
NDArray Wcx = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nOut]
NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0}); // updates cell weights [nOut, nOut]
NDArray br = (*b)({0, nOut}); // [nOut]
NDArray bu = (*b)({nOut, 2*nOut}); // [nOut]
// × means matrix multipication
// * means element-wise product or so called Hadamard product
// reset gate
r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
r->applyTransform(transform::Sigmoid, *r);
// update gate
u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
u->applyTransform(transform::Sigmoid, *u);
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc); // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
c->applyTransform(transform::Tanh, *c);
// cell output
h->assign(*u * *hI + (1.f - *u) * *c);
}
//////////////////////////////////////////////////////////////////////////
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b,
NDArray* gates, NDArray* h) {
//Inputs:
// x input [bS, nIn]
// hI previous cell output [bS, nOut], that is at previous time step t-1
// Wx weights for x - [nIn, 3*nOut]
// Wh weights for h - [nOut, 3*nOut]
// b biases [3*nOut]
// 3*nOut means following sequence: reset, update, cell
//Outputs:
// gates [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut]
// h current cell output [bS, nOut]
// formulas:
// zr = x × Wxr + hI × Whr + br
// zu = x × Wxu + hI × Whu + bu
// r = sigmoid(zr)
// u = sigmoid(zu)
// zc = x × Wxc + (r * hI) × Whc + bc
// c = tanh(zc)
// h = (1-u)*c + u*hI
const int bS = x->sizeAt(0);
const int nIn = x->sizeAt(1);
const int nOut = hI->sizeAt(1);
NDArray temp = gates->ulike();
MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut]
temp += *b;
MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut]
NDArray ru = (*gates)({0,0, 0,2*nOut}); // [bS, 2*nOut]
NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut]
NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut]
NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut]
// reset and update gates
ru += temp({0,0, 0,2*nOut});
ru.applyTransform(transform::Sigmoid, ru);
// cell gate
c.assign(c*r + temp({0,0, 2*nOut, 3*nOut}));
c.applyTransform(transform::Tanh, c);
// cell output
h->assign(u * *hI + (1.f - u) * c);
}
//////////////////////////////////////////////////////////////////////////
void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
// sL means time steps
// x input [sL, bS, nIn]
// hI initial cell output (at time step = 0) [bS, nOut]
// Wx input-to-hidden weights, [nIn, 3*nOut]
// Wh hidden-to-hidden weights, [nOut, 3*nOut]
// b biases, [3*nOut]
// h cell outputs at each time step [sL, bS, nOut]
const int sL = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int nOut = hI->sizeAt(1);
NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context);
auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn]
auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut]
// time loop
for (int t = 0; t < sL; ++t)
gruCell(context, xSet.at(t), t == 0 ? hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t));
}
//////////////////////////////////////////////////////////////////////////
void gruCellBp(sd::LaunchContext* context,
const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc) {
//Inputs:
// x input [bS, iS]
// hLast previous cell output [bS, nU], that is at previous time step t-1
// W weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
// dLdr gradient wrt reset gate, [bS, nU]
// dLdu gradient wrt update gate, [bS, nU]
// dLdc gradient wrt cell state, [bS, nU]
// dLdh gradient wrt current cell output, [bS, nU]
//Outputs:
// dLdx gradient wrt x, [bS, iS],
// dLdhLast gradient wrt hLast, [bS, nU]
// dLdW gradient wrt W, [iS+nU, 2*nU]
// dLdWc gradient wrt Wc, [iS+nU, nU]
// dLdb gradient wrt bru [2*nU]
// dLdbc gradient wrt bc [nU]
// * means element-wise product or so called Hadamard product
// × means matrix multiplication
/************************************************************************************************/
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray xT = x->transpose(); // [iS, bS]
NDArray hLastT = hLast->transpose(); // [nU, bS]
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
NDArray WrxT = Wrx.transpose(); // [nU, iS]
NDArray WuxT = Wux.transpose(); // [nU, iS]
NDArray WrhT = Wrh.transpose(); // [nU, nU]
NDArray WuhT = Wuh.transpose(); // [nU, nU]
NDArray WcxT = Wcx.transpose(); // [nU, iS]
NDArray WchT = Wch.transpose(); // [nU, nU]
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
// ***** feed forward step ***** //
// reset gate
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r.applyTransform(transform::Sigmoid, r);
// update gate
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u.applyTransform(transform::Sigmoid, u);
// cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c.applyTransform(transform::Tanh, c);
// h = (1 - u) * c + u * hPrev
// ***** back prop step ***** //
// notations:
// Zr = x × Wrx + hLast × Wrh + br
// Zu = x × Wux + hLast × Wuh + bu
// Sr = sigmoid(Zr)
// Su = sigmoid(Zu)
// Zc = x × Wcx + (r * hlast) × Wch + bc
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
// = dLdx_u + dLdx_c
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
// dZcdr = (... * hLast) × WchT
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
// drdx = drdZr * dZrdx
// dZrdx = ... × WrxT
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
// dLdhLast_h = dLdh * dhdhLas = dLdh * u
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
// dZrdWrx = xT × ...
// finally dLdWrx = xT × (dLdr * drdZr)
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
// dZrdWrh = hLastT × ...
// finally dLdWrh = hLastT × (dLdr * drdZr)
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
// dZudWux = xT × ...
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
// dZudWuh = hLastT × ...
// finally dLdWuh = hLastT × (dLdu * dudZu)
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
// dZcdWcx = xT × ...
// finally dLdWcx = xT × (dLdc * dcdZc)
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
// dZcdWch = (r*hLast)^T × ...
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
// = dLdr * drdZr * dZrdbr
// dZrdbr = 1
// finally dLdbr = dLdr * drdZr
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
// dZudbu = 1
// finally dLdbu = dLdu * dudZu
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
// dZcdbc = 1
// finally dLdbc = dLdc * dcdZc
NDArray dhdc = 1.f - u; // [bS, nU]
NDArray dhdu = *hLast - c; // [bS, nU]
NDArray dudZu = u * dhdc; // [bS, nU]
NDArray drdZr = r * (1.f - r); // [bS, nU]
NDArray dcdZc = 1.f - c * c; // [bS, nU]
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU]
dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU]
}
//////////////////////////////////////////////////////////////////////////
void gruCellBp(sd::LaunchContext* context,
const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
//Inputs:
// x input [bS, nIn]
// hI previous cell output [bS, nOut], that nIn at previous time step t-1
// Wx input-to-hidden weights - [nIn, 3*nOut]
// Wh hidden-to-hidden weights - [nOut, 3*nOut]
// b biases, [3*nOut] - reset and update gates
// dLdh gradient vs. ff output, [bS, nOut]
//Outputs:
// dLdx gradient vs. x, [bS, nIn],
// dLdhI gradient vs. hI, [bS, nOut]
// dLdWx gradient vs. W, [nIn, 3*nOut]
// dLdWh gradient vs. Wc, [nOut, 3*nOut]
// dLdb gradient vs. b [3*nOut]
// 3*nOut means following sequence: reset, update, cell
// * means element-wnIne product or so called Hadamard product
// × means matrix multiplication
// formulas:
// zr = x × Wxr + hI × Whr + br
// zu = x × Wxu + hI × Whu + bu
// r = sigmoid(zr)
// u = sigmoid(zu)
// zc = x × Wxc + (r * hI) × Whc + bc
// c = tanh(zc)
// h = (1-u)*c + u*hI
// dLdhI += dLdh; [bS, nOut]
// dhdc = 1 - u [bS, nOut]
// dhdu = -c + hI [bS, nOut]
// dcdzc = 1 - c*c; [bS, nOut]
// dudzu = u*(1-u) [bS, nOut]
// drdzr = r(1-r) [bS, nOut]
// dzcdr = (...*hI × WhcT) [bS, nOut]
// dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT); [bS, nOut]
// dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u) [bS, nOut]
// dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c) [bS, nOut]
// dLdx = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT, [bs, nOut] × [nOut, nIn] + ... = [bS, nIn]
// dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT, [bs, nOut] × [nOut, nOut] + ... = [bS, nOut]
// dLdWxr = xT × dLdzr [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxu = xT × dLdzu [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxc = xT × dLdzc [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWhr = xT × dLdzr [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWhu = xT × dLdzu [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWhc = (r*hI)T × dLdzc [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdbr = dLdzr.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbu = dLdzu.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbc = dLdzc.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
const int nOut = hI->sizeAt(1);
NDArray dLdz = gates->ulike(); // [bS, 3*nOut]
NDArray dLdzru = dLdz({0,0, 0,2*nOut}); // [bS, 2*nOut]
NDArray dLdzr = dLdz({0,0, 0,nOut}); // [bS, nOut]
NDArray dLdzu = dLdz({0,0, nOut,2*nOut}); // [bS, nOut]
NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut}); // [bS, nOut]
NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut]
NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut]
NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut]
NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose();
if(dLdh)
*dLdhI += *dLdh;
NDArray temp1 = 1 - u; // [bS, nOut]
// dLdzc
dLdzc.assign(*dLdhI * temp1 * (1-c*c)); // [bS, nOut]
// dLdzu
dLdzu.assign(*dLdhI * (*hI - c) * u * temp1); // [bS, nOut]
// dLdzr
NDArray temp2 = dLdzc * (*hI) * r *(1-r);
MmulHelper::mmul(&temp2, &WhcT, &dLdzr); // [bS, nOut] x [nOut, nOut] = [bS, nOut]
// dLdx
NDArray WxT = Wx->transpose();
MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn]
// dLdWx
*dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut]
// dLdb
*dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 3*nOut] -> reduce -> [3*nOut];
dLdzc *= r;
// dLdhI
NDArray WhT = Wh->transpose();
dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT)); // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut]
// dLdWr
*dLdWh += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut]
}
//////////////////////////////////////////////////////////////////////////
void gruTimeLoopBp(sd::LaunchContext * context,
const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
// sL means time steps
// x input [sL, bS, nIn]
// hI initial cell output (at time step = 0) [bS, nOut]
// Wx input-to-hidden weights, [nIn, 3*nOut]
// Wh hidden-to-hidden weights, [nOut, 3*nOut]
// b biases, [3*nOut]
// dLdh gradient vs. ff output, [sL, bS, nOut]
// dLdx gradient vs. x, [sL, bS, nIn],
// dLdhI gradient vs. hI, [bS, nOut]
// dLdWx gradient vs. W, [nIn, 3*nOut]
// dLdWh gradient vs. Wc, [nOut, 3*nOut]
// dLdb gradient vs. b [3*nOut]
const int sL = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int nOut = hI->sizeAt(1);
NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext());
NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext());
auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn]
auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut]
auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut]
auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut]
auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn]
hSet.at(0)->assign(hI);
// forward time loop
for (int t = 0; t < sL; ++t)
gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1));
// backward time loop
for (int t = sL-1; t >= 0; --t)
gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t),
dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb);
}
}
}
}

View File

@ -189,54 +189,6 @@ static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL
return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
//////////////////////////////////////////////////////////////////////////
// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2
// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1
static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) {
std::vector<Nd4jLong> outShape = x.rankOf() == 1 ? std::vector<Nd4jLong>({x.sizeAt(0), nOut, nOut}) : std::vector<Nd4jLong>({x.sizeAt(1), nOut, x.sizeAt(0), nOut});
NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext());
if(x.rankOf() == 1) {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
if(i1 == i2)
dzdy.p<double>(i0,i1,i2, x.e<double>(i0));
else
dzdy.p<double>(i0,i1,i2, 0);
}
}
}
};
samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1);
}
else {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) {
if(i1 == i3)
dzdy.p<double>(i0,i1,i2,i3, x.e<double>(i2,i0));
else
dzdy.p<double>(i0,i1,i2,i3, 0);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1);
}
return dzdy;
}
//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
@ -245,25 +197,25 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
NDArray* h, NDArray* c) {
// * -> means element-wise multiplication
// ^ -> means matrix multiplication
// × -> means matrix multiplication
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** the objective is to provide math-readable code **/
// equations (no peephole connections)
// it = σ(Wxi ^ xt + Wri ^ ht-1 + bi)
// ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf)
// c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc)
// it = σ(Wxi × xt + Wri × ht-1 + bi)
// ft = σ(Wxf × xt + Wrf × ht-1 + bf)
// c't = tanh(Wxc × xt + Wrc × ht-1 + bc)
// ct = ft * ct-1 + it * c't
// ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo)
// ot = σ(Wxo × xt + Wro × ht-1 + bo)
// ht = ot * tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi)
// ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf)
// c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc)
// it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi)
// ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf)
// c't = tanh(Wxc × xt + Wrc × ht-1 + bc)
// ct = ft * ct-1 + it * c't
// ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo)
// ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo)
// ht = ot * tanh(ct)
@ -399,7 +351,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
//////////////////////////////////////////////////////////////////////////
void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdc,
const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
@ -407,10 +359,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con
/** the objective is to provide math-readable code **/
// equations (no peephole connections)
// zi = x ^ Wxi + hI ^ Wri + bi
// zf = x ^ Wxf + hI ^ Wrf + bf
// zg = x ^ Wxg + hI ^ Wrg + bg
// zo = x ^ Wxo + hI ^ Wro + bo
// zi = x × Wxi + hI × Wri + bi
// zf = x × Wxf + hI × Wrf + bf
// zg = x × Wxg + hI × Wrg + bg
// zo = x × Wxo + hI × Wro + bo
// i = act(zi)
// f = act(zf)
// g = actC(zg)
@ -419,10 +371,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con
// h = o * actH(c)
// equations (peephole connections are present)
// zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi
// zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf
// zg = x ^ Wxg + hI ^ Wrg + bg
// zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo
// zi = x × Wxi + hI × Wri + cI * Wpi + bi
// zf = x × Wxf + hI × Wrf + cI * Wpf + bf
// zg = x × Wxg + hI × Wrg + bg
// zo = x × Wxo + hI × Wro + c * Wpo + bo
// i = act(zi)
// f = act(zf)
// g = actC(zg)
@ -449,18 +401,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con
// params[11] - beta value for output activation
// INPUTS:
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr
// z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut]
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
// dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
// z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
// c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut]
// OUTPUTS:
// dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
@ -485,19 +438,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con
// dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut])
// dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut])
// dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn]
// dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut]
// dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn]
// dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut]
// dLdcI = dLdcI*dcdcI, [bS, nOut]
// dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxg = xT^dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut]
// dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut]
// dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
// dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
@ -563,10 +516,12 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con
if(dLdh)
*dLdhI += *dLdh;
if(dLdc)
*dLdcI += *dLdc;
else
*dLdcI += *dLdhI * dhdc;
if(dLdhL)
*dLdhI += *dLdhL;
if(dLdcL)
*dLdcI += *dLdcL;
*dLdcI += *dLdhI * dhdc;
dLdzi *= *dLdcI; // [bS, nOut](or[nOut])
dLdzf *= *dLdcI; // [bS, nOut](or[nOut])
@ -662,25 +617,27 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const std::vector<Nd4jLong> shapeOut = {bS, nOut};
const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType());
auto h0 = const_cast<NDArray*>(hI);
if(!hI) {
h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
h0->nullify();
}
auto c0 = const_cast<NDArray*>(cI);
if(!cI) {
c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
c0->nullify();
}
auto ct = cL;
if(!cL)
ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
ct = new NDArray(x->ordering(), shapeOut, type, x->getContext());
auto ht = hL;
if(!h && !hL)
ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
ht = new NDArray(x->ordering(), shapeOut, type, x->getContext());
// create sets of required (depends on seqLen presence) sub-arrays
std::vector<int> dims;
@ -989,17 +946,19 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const int nOut = Wx->sizeAt(-1) / 4;
const auto type = dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType());
auto dLdh0 = dLdhI;
if(!hI)
dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically
dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically
auto dLdc0 = dLdcI;
if(!cI)
dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically
dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically
NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext());
NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext());
NDArray a = z.ulike();
NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext());
NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext());
NDArray c = h.ulike();
// create sets of required (depends on seqLen presence) sub-arrays
@ -1041,9 +1000,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
if(dLdh)
dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut]
if(!dLdh && dLdhL)
if(dLdhL)
dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
if(!dLdh && !dLdhL)
if(dLdcL)
dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
}
@ -1054,13 +1013,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
if(!seqLen) { // seqLen is absent
if(hI)
h({0,1, 0,0, 0,0}).assign(hI);
hSet->at(0)->assign(hI);
else
h({0,1, 0,0, 0,0}).nullify();
hSet->at(0)->nullify();
if(cI)
c({0,1, 0,0, 0,0}).assign(cI);
cSet->at(0)->assign(cI);
else
c({0,1, 0,0, 0,0}).nullify();
cSet->at(0)->nullify();
// ff
for (int t = 0; t < sL; ++t)
@ -1068,9 +1027,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// bp
for (int t = sL-1; t >= 0; --t) {
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr);
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc,
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr;
const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr;
const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr;
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL,
zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
}
}
@ -1086,13 +1046,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
}
if(hI)
h({0,1, e,e+1, 0,0}).assign(hISet->at(e));
hSet->at(e)->assign(hISet->at(e));
else
h({0,1, e,e+1, 0,0}).nullify();
hSet->at(e)->nullify();
if(cI)
c({0,1, e,e+1, 0,0}).assign(cISet->at(e));
cSet->at(e)->assign(cISet->at(e));
else
c({0,1, e,e+1, 0,0}).nullify();
cSet->at(e)->nullify();
// ff
for (int t = 0; t < limit; ++t)
@ -1102,9 +1062,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// bp
for (int t = limit-1; t >= 0; --t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc,
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr;
const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr;
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
@ -1119,13 +1080,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
if(!seqLen) { // backward or bidirectional, seqLen is absent
if(hI)
h({sL,sL+1, 0,0, 0,0}).assign(hI);
hSet->at(sL)->assign(hI);
else
h({sL,sL+1, 0,0, 0,0}).nullify();
hSet->at(sL)->nullify();
if(cI)
c({sL,sL+1, 0,0, 0,0}).assign(cI);
cSet->at(sL)->assign(cI);
else
c({sL,sL+1, 0,0, 0,0}).nullify();
cSet->at(sL)->nullify();
// ff
for (int t = sL-1; t >= 0; --t)
@ -1133,9 +1094,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// bp
for (int t = 0; t < sL; ++t) {
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr);
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc,
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr;
const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr;
const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr;
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL,
zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
}
}
@ -1151,13 +1113,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
}
if(hI)
h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e));
hSet->at(sL*bS + e)->assign(hISet->at(e));
else
h({sL,sL+1, e,e+1, 0,0}).nullify();
hSet->at(sL*bS + e)->nullify();
if(cI)
c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e));
cSet->at(sL*bS + e)->assign(cISet->at(e));
else
c({sL,sL+1, e,e+1, 0,0}).nullify();
cSet->at(sL*bS + e)->nullify();
// ff
for (int t = sL - 1; t >= sL-limit; --t)
@ -1167,9 +1129,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// bp
for (int t = sL-limit; t < sL; ++t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc,
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr;
const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr;
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
@ -1206,9 +1169,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// bp
for (int t = 0; t < limit; ++t) {
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr);
const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr);
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc,
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr;
const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr;
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
}
@ -1248,10 +1212,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// /** the objective is to provide math-readable code **/
// // equations (no peephole connections)
// // zi = x ^ Wxi + hI ^ Wri + bi
// // zf = x ^ Wxf + hI ^ Wrf + bf
// // zg = x ^ Wxg + hI ^ Wrg + bg
// // zo = x ^ Wxo + hI ^ Wro + bo
// // zi = x × Wxi + hI × Wri + bi
// // zf = x × Wxf + hI × Wrf + bf
// // zg = x × Wxg + hI × Wrg + bg
// // zo = x × Wxo + hI × Wro + bo
// // i = act(zi)
// // f = act(zf)
// // g = actC(zg)
@ -1260,10 +1224,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// // h = o * actH(c)
// // equations (peephole connections are present)
// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi
// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf
// // zg = x ^ Wxg + hI ^ Wrg + bg
// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo
// // zi = x × Wxi + hI × Wri + cI * Wpi + bi
// // zf = x × Wxf + hI × Wrf + cI * Wpf + bf
// // zg = x × Wxg + hI × Wrg + bg
// // zo = x × Wxo + hI × Wro + c * Wpo + bo
// // i = act(zi)
// // f = act(zf)
// // g = actC(zg)
@ -1333,13 +1297,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
// // oFactor = *dLdh*dhdzo [bS, nOut]
// // tempC = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0;
// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT
// // tempO = dhdzo^WroT
// // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT
// // tempO = dhdzo×WroT
// // dhIdcI = dhdc_from_previous_time_step
// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn]
// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut]
// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn]
// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut]
// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut]
// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]

View File

@ -42,7 +42,7 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra
//////////////////////////////////////////////////////////////////////////
void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const NDArray* dLdh, const NDArray* dLdc,
const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp);

View File

@ -2084,11 +2084,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2097,6 +2097,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
@ -2113,12 +2114,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::SUM, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
@ -2131,63 +2132,6 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) {
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = false; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {4*nOut}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
cI.linspace(0.7,-0.1);
Wx.linspace(1,-0.1);
Wr.linspace(-1,0.1);
Wp.linspace(0.2,0.2);
b.linspace(1,-0.15);
std::vector<double> tArgs = {cellClip};
std::vector<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
const int sL = 3;
const int bS = 2;
const int nIn = 2;
const int nOut = 3;
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
@ -2199,11 +2143,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = false; // output at last time step
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = false; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2233,13 +2177,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN);
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) {
const int sL = 4;
const int bS = 3;
@ -2258,10 +2202,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2272,6 +2216,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2286,18 +2232,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) {
const int sL = 3;
const int bS = 2;
@ -2315,11 +2261,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2328,7 +2274,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2343,18 +2291,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector<bool>(), {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) {
const int sL = 3;
const int bS = 2;
@ -2373,10 +2321,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2387,6 +2335,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2401,18 +2351,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) {
const int sL = 3;
const int bS = 2;
@ -2430,11 +2380,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = false; // dLdh per each time step
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = true; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2444,7 +2394,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2459,18 +2411,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) {
const int sL = 3;
const int bS = 2;
@ -2489,10 +2447,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2503,6 +2461,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2517,18 +2477,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
ASSERT_TRUE(isGradCorrect);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) {
const int sL = 3;
const int bS = 2;
@ -2547,10 +2513,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // dLdh per each time step
const auto retLastH = false; // output at last time step
const auto retLastC = false; // cells state at last time step
const auto retLastH = true; // output at last time step
const auto retLastC = true; // cells state at last time step
const double cellClip = 0.5; // do not apply clipping
const double cellClip = 0.5; // clipping
NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE);
NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE);
@ -2561,6 +2527,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE);
x.linspace(-2,0.1);
hI.linspace(-1.5,0.1);
@ -2575,12 +2543,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) {
std::vector<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs);
// const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs);
sd::ops::lstmLayer opFF;
sd::ops::lstmLayer_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true});
ASSERT_TRUE(isGradCorrect);
}

View File

@ -1904,6 +1904,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
ASSERT_TRUE(isGradCorrect);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
@ -1922,3 +1923,68 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
ASSERT_TRUE(isGradCorrect);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, gru_1) {
const int sL = 3;
const int bS = 2;
const int nIn = 5;
const int nOut = 4;
NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32);
NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32);
NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32);
NDArray b('c', {3*nOut}, sd::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014,
0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32);
Wx = 0.003;
Wh = 0.006;
b = 0.5;
NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE);
sd::ops::gru op;
auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto* h = results.at(0);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, gru_bp_1) {
const int sL = 3;
const int bS = 2;
const int nIn = 5;
const int nOut = 4;
NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE);
NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE);
NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE);
NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE);
NDArray b('c', {3*nOut}, sd::DataType::DOUBLE);
NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
Wx.linspace(1,-0.1);
Wh.linspace(0.2,0.2);
b.linspace(1,-0.15);
const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {});
const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});
sd::ops::gru opFF;
sd::ops::gru_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
}