* - update gru ff op Signed-off-by: Yurii <iuriish@yahoo.com> * - implementation and testing gru_bp op Signed-off-by: Yurii <iuriish@yahoo.com> * - neglect dependencies between dLdh/dLdhLast/dLdcLast in lstmLayer backprop Signed-off-by: Yurii <iuriish@yahoo.com>
		
			
				
	
	
		
			339 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			339 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /*******************************************************************************
 | ||
|  * Copyright (c) 2020 Konduit K.K.
 | ||
|  *
 | ||
|  * This program and the accompanying materials are made available under the
 | ||
|  * terms of the Apache License, Version 2.0 which is available at
 | ||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||
|  *
 | ||
|  * Unless required by applicable law or agreed to in writing, software
 | ||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | ||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | ||
|  * License for the specific language governing permissions and limitations
 | ||
|  * under the License.
 | ||
|  *
 | ||
|  * SPDX-License-Identifier: Apache-2.0
 | ||
|  ******************************************************************************/
 | ||
| 
 | ||
| //
 | ||
| // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||
| //
 | ||
| 
 | ||
| #include <system/op_boilerplate.h>
 | ||
| #if NOT_EXCLUDED(OP_lstmLayerCell)
 | ||
| 
 | ||
| #include <ops/declarable/CustomOperations.h>
 | ||
| #include<ops/declarable/helpers/lstmLayer.h>
 | ||
| 
 | ||
| namespace sd {
 | ||
| namespace ops {
 | ||
| 
 | ||
| 
 | ||
| //////////////////////////////////////////////////////////////////////////
 | ||
| CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {
 | ||
| 
 | ||
|     // equations (no peephole connections)
 | ||
|     // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
 | ||
|     // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
 | ||
|     // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
 | ||
|     // ct  = ft ◦ ct-1 + it ◦ c't
 | ||
|     // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
 | ||
|     // ht  = ot ◦ tanh(ct)
 | ||
| 
 | ||
|     // equations (peephole connections are present)
 | ||
|     // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
 | ||
|     // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
 | ||
|     // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
 | ||
|     // ct  = clip(ft ◦ ct-1 + it ◦ c't)
 | ||
|     // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
 | ||
|     // ht  = ot ◦ tanh(ct)
 | ||
| 
 | ||
|     // notations:
 | ||
|     // bS - batch size
 | ||
|     // nIn - input size
 | ||
|     // nOut - output size (hidden size)
 | ||
| 
 | ||
|     // INPUTS:
 | ||
|     // input x:                          [bS, nIn] or [nIn]
 | ||
|     // input weights Wx:                 [nIn, 4*nOut]
 | ||
|     // recurrent weights Wr:             [nOut, 4*nOut]
 | ||
|     // initial (previous) output hI:     [bS, nOut] or [nOut]
 | ||
|     // initial (previous) cell state cI: [bS, nOut] or [nOut]
 | ||
|     // biases b (optional):              [4*nOut]
 | ||
|     // peephole weights Wp (optional):   [3*nOut]
 | ||
| 
 | ||
|     // OUTPUTS:
 | ||
|     // current output h:     [bS, nOut] or [nOut]
 | ||
|     // current cell state c: [bS, nOut] or [nOut]
 | ||
| 
 | ||
|     // !!! dimension 4*nOut implies order it, ft, c't, ot
 | ||
|     // !!! dimension 3*nOut implies order it, ft, ot
 | ||
| 
 | ||
|     // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
 | ||
|     const auto gateAct       = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
 | ||
|     const auto cellAct       = INT_ARG(1);    // activation for cell state (c)
 | ||
|     const auto outAct        = INT_ARG(2);    // activation for output (h)
 | ||
| 
 | ||
|     const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
 | ||
|     const auto hasPH      = B_ARG(1);   // indicates whether peephole connections are present
 | ||
| 
 | ||
|     const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
 | ||
|     const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
 | ||
|     const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
 | ||
|     const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
 | ||
|     const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
 | ||
|     const auto outActHasBeta   = outAct  == 3 || outAct  == 6;
 | ||
| 
 | ||
|     uint count = 1;
 | ||
|     const auto cellClip = T_ARG(0);                                     // cell clipping value, if it = 0 then do not apply clipping
 | ||
|     const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
 | ||
|     const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
 | ||
|     const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
 | ||
|     const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
 | ||
|     const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
 | ||
|     const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
 | ||
| 
 | ||
|     count = 3;
 | ||
|     const auto x  = INPUT_VARIABLE(0);                              // input
 | ||
|     const auto Wx = INPUT_VARIABLE(1);                              // input weights
 | ||
|     const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
 | ||
|     const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
 | ||
|     const auto hI = INPUT_VARIABLE(count++);                        // initial output
 | ||
|     const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
 | ||
|     const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;        // peephole weights
 | ||
| 
 | ||
|     REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");
 | ||
| 
 | ||
|     auto h = OUTPUT_VARIABLE(0);
 | ||
|     auto c = OUTPUT_VARIABLE(1);
 | ||
| 
 | ||
|     // evaluate dimensions
 | ||
|     const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
 | ||
|     const Nd4jLong nIn  = x->sizeAt(-1);
 | ||
|     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
 | ||
| 
 | ||
|     // inputs validations
 | ||
|     // Wx validation
 | ||
|     if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
 | ||
|     // Wr validation
 | ||
|     if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
 | ||
|     // initial output/cell validation
 | ||
|     std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
 | ||
|     REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
 | ||
|     REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
 | ||
|     // biases validation
 | ||
|     if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
 | ||
|     // peephole weights validation
 | ||
|     if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
 | ||
| 
 | ||
|     std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
 | ||
|                                  static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
 | ||
|                                  static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
 | ||
|                                  static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
 | ||
| 
 | ||
|     helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);
 | ||
| 
 | ||
|     return Status::OK();
 | ||
| }
 | ||
| 
 | ||
| DECLARE_TYPES(lstmLayerCell) {
 | ||
|     getOpDescriptor()
 | ||
|         ->setAllowedInputTypes(sd::DataType::ANY)
 | ||
|         ->setAllowedOutputTypes({ALL_FLOATS});
 | ||
| }
 | ||
| 
 | ||
| 
 | ||
| DECLARE_SHAPE_FN(lstmLayerCell) {
 | ||
| 
 | ||
|     const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
 | ||
| 
 | ||
|     uint count = hasBiases ? 4 : 3;
 | ||
|     const auto hI = INPUT_VARIABLE(count++);        // initial output
 | ||
|     const auto cI = INPUT_VARIABLE(count);          // initial cell state
 | ||
| 
 | ||
|     return new ShapeList({hI->getShapeInfo(), cI->getShapeInfo()});
 | ||
| }
 | ||
| 
 | ||
| //////////////////////////////////////////////////////////////////////////
 | ||
| CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {
 | ||
| 
 | ||
|     // equations (no peephole connections)
 | ||
|     // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
 | ||
|     // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
 | ||
|     // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
 | ||
|     // ct  = ft ◦ ct-1 + it ◦ c't
 | ||
|     // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
 | ||
|     // ht  = ot ◦ tanh(ct)
 | ||
| 
 | ||
|     // equations (peephole connections are present)
 | ||
|     // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
 | ||
|     // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
 | ||
|     // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
 | ||
|     // ct  = clip(ft ◦ ct-1 + it ◦ c't)
 | ||
|     // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
 | ||
|     // ht  = ot ◦ tanh(ct)
 | ||
| 
 | ||
|     // notations:
 | ||
|     // bS - batch size
 | ||
|     // nIn - input size
 | ||
|     // nOut - output size (hidden size)
 | ||
| 
 | ||
|     // INPUTS:
 | ||
|     // input x:                          [bS, nIn] or [nIn]
 | ||
|     // input weights Wx:                 [nIn, 4*nOut]
 | ||
|     // recurrent weights Wr:             [nOut, 4*nOut]
 | ||
|     // initial (previous) output hI:     [bS, nOut] or [nOut]
 | ||
|     // initial (previous) cell state cI: [bS, nOut] or [nOut]
 | ||
|     // gradient wrt output dLdh:         [bS, nOut] or [nOut]
 | ||
|     // gradient wrt cell state dLdc:     [bS, nOut] or [nOut]
 | ||
|     // peephole weights Wp (optional):   [3*nOut]
 | ||
|     // biases b (optional):              [4*nOut]
 | ||
| 
 | ||
|     // OUTPUTS:
 | ||
|     // gradient wrt x dLdx:              [bS, nIn] or [nIn]
 | ||
|     // gradient wrt Wx dLdWx:            [nIn, 4*nOut]
 | ||
|     // gradient wrt Wr dLdWr:            [nOut, 4*nOut]
 | ||
|     // gradient wrt hI dLdhI:            [bS, nOut] or [nOut]
 | ||
|     // gradient wrt cI dLdcI:            [bS, nOut] or [nOut]
 | ||
|     // gradient wrt b dLdb (optional):   [4*nOut]
 | ||
|     // gradient wrt Wp dLdWp (optional): [3*nOut]
 | ||
| 
 | ||
| 
 | ||
|     // !!! dimension 4*nOut implies order it, ft, c't, ot
 | ||
|     // !!! dimension 3*nOut implies order it, ft, ot
 | ||
| 
 | ||
|     // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
 | ||
|     const auto gateAct       = INT_ARG(0);    // activation for input (i), forget (f) and output (o) gates
 | ||
|     const auto cellAct       = INT_ARG(1);    // activation for cell state (c)
 | ||
|     const auto outAct        = INT_ARG(2);    // activation for output (h)
 | ||
| 
 | ||
|     const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
 | ||
|     const auto hasPH      = B_ARG(1);   // indicates whether peephole connections are present
 | ||
| 
 | ||
|     const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
 | ||
|     const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
 | ||
|     const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
 | ||
|     const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
 | ||
|     const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
 | ||
|     const auto outActHasBeta   = outAct  == 3 || outAct  == 6;
 | ||
| 
 | ||
|     uint count = 1;
 | ||
|     const auto cellClip = T_ARG(0);                                     // cell clipping value, if it = 0 then do not apply clipping
 | ||
|     const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
 | ||
|     const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
 | ||
|     const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
 | ||
|     const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
 | ||
|     const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
 | ||
|     const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
 | ||
| 
 | ||
|     count = 3;
 | ||
|     const auto x  = INPUT_VARIABLE(0);                              // input
 | ||
|     const auto Wx = INPUT_VARIABLE(1);                              // input weights
 | ||
|     const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
 | ||
|     const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
 | ||
|     const auto hI = INPUT_VARIABLE(count++);                        // initial output
 | ||
|     const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
 | ||
|     const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr;      // peephole weights
 | ||
|     const auto dLdh = INPUT_VARIABLE(count);                        // gradient wrt output
 | ||
| 
 | ||
|     REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");
 | ||
| 
 | ||
|     count = 3;
 | ||
|     auto dLdx  = OUTPUT_VARIABLE(0);
 | ||
|     auto dLdWx = OUTPUT_VARIABLE(1);
 | ||
|     auto dLdWr = OUTPUT_VARIABLE(2);
 | ||
|     auto dLdb  = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
 | ||
|     auto dLdhI = OUTPUT_VARIABLE(count++);
 | ||
|     auto dLdcI = OUTPUT_VARIABLE(count++);
 | ||
|     auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;
 | ||
| 
 | ||
|     // evaluate dimensions
 | ||
|     const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
 | ||
|     const Nd4jLong nIn  = x->sizeAt(-1);
 | ||
|     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
 | ||
| 
 | ||
|     // inputs validations
 | ||
|     // Wx validation
 | ||
|     if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
 | ||
|     // Wr validation
 | ||
|     if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
 | ||
|     // initial output/cell validation
 | ||
|     std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
 | ||
|     REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
 | ||
|     REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
 | ||
|     REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
 | ||
|     // biases validation
 | ||
|     if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
 | ||
|     if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
 | ||
|     // peephole weights validation
 | ||
|     if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
 | ||
|     if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
 | ||
|         REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());
 | ||
| 
 | ||
| 
 | ||
|     std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
 | ||
|                                  static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
 | ||
|                                  static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
 | ||
|                                  static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
 | ||
| 
 | ||
|     std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});
 | ||
| 
 | ||
|     NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
 | ||
|     NDArray a = z.ulike();
 | ||
|     NDArray h = cI->ulike();
 | ||
|     NDArray c = cI->ulike();
 | ||
| 
 | ||
|     helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);
 | ||
| 
 | ||
|     helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);
 | ||
| 
 | ||
|     return Status::OK();
 | ||
| }
 | ||
| 
 | ||
| DECLARE_TYPES(lstmLayerCellBp) {
 | ||
|     getOpDescriptor()
 | ||
|         ->setAllowedInputTypes(sd::DataType::ANY)
 | ||
|         ->setAllowedOutputTypes({ALL_FLOATS});
 | ||
| }
 | ||
| 
 | ||
| 
 | ||
| DECLARE_SHAPE_FN(lstmLayerCellBp) {
 | ||
| 
 | ||
|     const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
 | ||
|     const auto hasPH      = B_ARG(1);   // indicates whether peephole connections are present
 | ||
| 
 | ||
|     uint count = 3;
 | ||
|     const auto x  = INPUT_VARIABLE(0);                              // input
 | ||
|     const auto Wx = INPUT_VARIABLE(1);                              // input weights
 | ||
|     const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
 | ||
|     const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
 | ||
|     const auto hI = INPUT_VARIABLE(count++);                        // initial output
 | ||
|     const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
 | ||
|     const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;        // peephole weights
 | ||
| 
 | ||
|     std::vector<Nd4jLong*> shapes = {x->getShapeInfo(), Wx->getShapeInfo(), Wr->getShapeInfo()};
 | ||
| 
 | ||
|     if(b != nullptr)
 | ||
|         shapes.push_back(b->getShapeInfo());
 | ||
| 
 | ||
|     shapes.push_back(hI->getShapeInfo());
 | ||
|     shapes.push_back(cI->getShapeInfo());
 | ||
| 
 | ||
|     if(Wp != nullptr)
 | ||
|         shapes.push_back(Wp->getShapeInfo());
 | ||
| 
 | ||
|     return new ShapeList(shapes);
 | ||
| }
 | ||
| 
 | ||
| }
 | ||
| }
 | ||
| 
 | ||
| #endif |