/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayerCell)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCell, 5, 2, false, 1, 3) {

    // equations (no peephole connections)
    // it = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct = ft ◦ ct-1 + it ◦ c't
    // ot = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct = clip(ft ◦ ct-1 + it ◦ c't)
    // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht = ot ◦ tanh(ct)
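
    // notation: σ denotes the gate activation (sigmoid in the standard formulation),
    // * denotes matrix multiplication and ◦ denotes the element-wise (Hadamard) product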

    // notations:
    // bS - batch size
    // nIn - input size
    // nOut - output size (hidden size)

    // INPUTS:
    // input x:                          [bS, nIn] or [nIn]
    // input weights Wx:                 [nIn, 4*nOut]
    // recurrent weights Wr:             [nOut, 4*nOut]
    // initial (previous) output hI:     [bS, nOut] or [nOut]
    // initial (previous) cell state cI: [bS, nOut] or [nOut]
    // biases b (optional):              [4*nOut]
    // peephole weights Wp (optional):   [3*nOut]

    // OUTPUTS:
    // current output h:     [bS, nOut] or [nOut]
    // current cell state c: [bS, nOut] or [nOut]

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu,
    // 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(1); // activation for cell state (c)
    const auto outAct  = INT_ARG(2); // activation for output (h)

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1); // indicates whether peephole connections are present

    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta  = outAct  == 3 || outAct  == 6;

    uint count = 1;
    const auto cellClip  = T_ARG(0); // cell clipping value; if 0, clipping is not applied
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;
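
    // alpha/beta args are consumed in the fixed order above, and only for activations that need them;
    // e.g. with gateAct=2 (sigmoid), cellAct=0 (tanh) and outAct=0 (tanh) none is consumed,
    // and T_ARGS reduces to the single cell clipping value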

    count = 3;
    const auto x  = INPUT_VARIABLE(0);                              // input
    const auto Wx = INPUT_VARIABLE(1);                              // input weights
    const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
    const auto hI = INPUT_VARIABLE(count++);                        // initial output
    const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;        // peephole weights
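
    // the optional biases input shifts the indices of the inputs that follow it;
    // e.g. with hasBiases=true and hasPH=false the input order is {x, Wx, Wr, b, hI, cI},
    // while with both flags false it is {x, Wx, Wr, hI, cI}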

    REQUIRE_TRUE(cellClip >= 0, 0, "LSTM_LAYER_CELL operation: cell clipping value should be nonnegative (>=0) !");

    auto h = OUTPUT_VARIABLE(0);
    auto c = OUTPUT_VARIABLE(1);

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
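
    // for example, x: [2, 3] and Wx: [3, 16] give bS = 2, nIn = 3 and nOut = 16 / 4 = 4,
    // so Wr is expected as [4, 16], b as [16] and Wp as [12]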

    // inputs validations
    // Wx validation
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());

    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct),  static_cast<float>(outAlpha),  static_cast<float>(outBeta)};
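
    // params layout passed to the helper: [0-1] ignored by this op, [2] cell clip,
    // [3-5] gate activation/alpha/beta, [6-8] cell activation/alpha/beta, [9-11] output activation/alpha/beta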

    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, h, c);

    return Status::OK();
}

DECLARE_TYPES(lstmLayerCell) {
    getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayerCell) {

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided

    uint count = hasBiases ? 4 : 3;
    const auto hI = INPUT_VARIABLE(count++); // initial output
    const auto cI = INPUT_VARIABLE(count);   // initial cell state

    return new ShapeList({hI->shapeInfo(), cI->shapeInfo()});
}
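
// Usage sketch (a hypothetical call in the evaluate() style used by the libnd4j tests;
// shapes and argument values below are illustrative assumptions, not part of this file):
//
//   sd::ops::lstmLayerCell op;
//   // x: [bS, nIn], Wx: [nIn, 4*nOut], Wr: [nOut, 4*nOut], b: [4*nOut], hI/cI: [bS, nOut]
//   auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI},
//                              {0.0},          // T_ARGS: cell clip only (no alpha/beta needed)
//                              {2, 0, 0},      // I_ARGS: sigmoid gates, tanh cell, tanh output
//                              {true, false}); // B_ARGS: biases present, no peepholes
//   auto h = results.at(0);                    // current output
//   auto c = results.at(1);                    // current cell state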

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) {

    // equations (no peephole connections)
    // it = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct = ft ◦ ct-1 + it ◦ c't
    // ot = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct = clip(ft ◦ ct-1 + it ◦ c't)
    // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht = ot ◦ tanh(ct)

    // notations:
    // bS - batch size
    // nIn - input size
    // nOut - output size (hidden size)

    // INPUTS:
    // input x:                          [bS, nIn] or [nIn]
    // input weights Wx:                 [nIn, 4*nOut]
    // recurrent weights Wr:             [nOut, 4*nOut]
    // initial (previous) output hI:     [bS, nOut] or [nOut]
    // initial (previous) cell state cI: [bS, nOut] or [nOut]
    // gradient wrt output dLdh:         [bS, nOut] or [nOut]
    // gradient wrt cell state dLdc:     [bS, nOut] or [nOut]
    // peephole weights Wp (optional):   [3*nOut]
    // biases b (optional):              [4*nOut]

    // OUTPUTS:
    // gradient wrt x dLdx:              [bS, nIn] or [nIn]
    // gradient wrt Wx dLdWx:            [nIn, 4*nOut]
    // gradient wrt Wr dLdWr:            [nOut, 4*nOut]
    // gradient wrt hI dLdhI:            [bS, nOut] or [nOut]
    // gradient wrt cI dLdcI:            [bS, nOut] or [nOut]
    // gradient wrt b dLdb (optional):   [4*nOut]
    // gradient wrt Wp dLdWp (optional): [3*nOut]

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu,
    // 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(0); // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(1); // activation for cell state (c)
    const auto outAct  = INT_ARG(2); // activation for output (h)

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1); // indicates whether peephole connections are present

    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta  = outAct  == 3 || outAct  == 6;

    uint count = 1;
    const auto cellClip  = T_ARG(0); // cell clipping value; if 0, clipping is not applied
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    count = 3;
    const auto x    = INPUT_VARIABLE(0);                              // input
    const auto Wx   = INPUT_VARIABLE(1);                              // input weights
    const auto Wr   = INPUT_VARIABLE(2);                              // recurrent weights
    const auto b    = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
    const auto hI   = INPUT_VARIABLE(count++);                        // initial output
    const auto cI   = INPUT_VARIABLE(count++);                        // initial cell state
    const auto Wp   = hasPH ? INPUT_VARIABLE(count++) : nullptr;      // peephole weights
    const auto dLdh = INPUT_VARIABLE(count);                          // gradient wrt output
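
    // dLdh always comes last, after the optional b and Wp inputs;
    // e.g. with hasBiases=false and hasPH=true the input order is {x, Wx, Wr, hI, cI, Wp, dLdh}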

    REQUIRE_TRUE(cellClip >= 0, 0, "LSTM_LAYER_CELL_BP operation: cell clipping value should be nonnegative (>=0) !");

    count = 3;
    auto dLdx  = OUTPUT_VARIABLE(0);
    auto dLdWx = OUTPUT_VARIABLE(1);
    auto dLdWr = OUTPUT_VARIABLE(2);
    auto dLdb  = hasBiases ? OUTPUT_VARIABLE(count++) : nullptr;
    auto dLdhI = OUTPUT_VARIABLE(count++);
    auto dLdcI = OUTPUT_VARIABLE(count++);
    auto dLdWp = hasPH ? OUTPUT_VARIABLE(count) : nullptr;

    // evaluate dimensions
    const Nd4jLong bS   = x->rankOf() == 1 ? 0 : x->sizeAt(0);
    const Nd4jLong nIn  = x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // inputs validations
    // Wx validation
    if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
    // Wr validation
    if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
    // initial output/cell validation
    std::vector<Nd4jLong> exphIcIShape = x->rankOf() == 1 ? std::vector<Nd4jLong>{nOut} : std::vector<Nd4jLong>{bS, nOut};
    REQUIRE_TRUE(hI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(hI).c_str());
    REQUIRE_TRUE(cI->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    REQUIRE_TRUE(dLdh->isSameShape(exphIcIShape), 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdh gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(exphIcIShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
    // biases validation
    if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
    if(dLdb != nullptr && (dLdb->rankOf() != 1 || dLdb->sizeAt(0) != 4*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdb gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(dLdb).c_str());
    // peephole weights validation
    if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
    if(dLdWp != nullptr && (dLdWp->rankOf() != 1 || dLdWp->sizeAt(0) != 3*nOut))
        REQUIRE_TRUE(false, 0, "LSTM_LAYER_CELL_BP operation: wrong shape of dLdWp gradient, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(dLdWp).c_str());

    std::vector<float> params = {static_cast<float>(0)/*ignore*/, static_cast<float>(0)/*ignore*/, static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct),  static_cast<float>(outAlpha),  static_cast<float>(outBeta)};
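
    // backprop needs the forward-pass intermediates (gate pre-activations z, gate activations a
    // and the cell state c), so the forward cell is re-run locally before computing the gradients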

    std::vector<Nd4jLong> zShape = x->rankOf() == 1 ? std::vector<Nd4jLong>({4*nOut}) : std::vector<Nd4jLong>({bS, 4*nOut});

    NDArray z(x->ordering(), zShape, x->dataType(), block.launchContext());
    NDArray a = z.ulike();
    NDArray h = cI->ulike();
    NDArray c = cI->ulike();

    helpers::lstmLayerCell(x, Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c);

    helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp);

    return Status::OK();
}

DECLARE_TYPES(lstmLayerCellBp) {
    getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}

DECLARE_SHAPE_FN(lstmLayerCellBp) {

    const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
    const auto hasPH     = B_ARG(1); // indicates whether peephole connections are present

    uint count = 3;
    const auto x  = INPUT_VARIABLE(0);                              // input
    const auto Wx = INPUT_VARIABLE(1);                              // input weights
    const auto Wr = INPUT_VARIABLE(2);                              // recurrent weights
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
    const auto hI = INPUT_VARIABLE(count++);                        // initial output
    const auto cI = INPUT_VARIABLE(count++);                        // initial cell state
    const auto Wp = hasPH ? INPUT_VARIABLE(count) : nullptr;        // peephole weights

    auto shapes = SHAPELIST(x->shapeInfo(), Wx->shapeInfo(), Wr->shapeInfo());

    if(b != nullptr)
        shapes->push_back(b->shapeInfo());

    shapes->push_back(hI->shapeInfo());
    shapes->push_back(cI->shapeInfo());

    if(Wp != nullptr)
        shapes->push_back(Wp->shapeInfo());

    return shapes;
}

}
}

#endif