// cavis/libnd4j/include/ops/declarable/generic/recurrent/sru.cpp
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017
//
// @author Yurii Shyrma
//
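// Forward recurrence computed per time step t (per arXiv:1709.02755; the same
// equations appear in the commented-out reference implementation at the bottom
// of this file):
//   [z_t; f'_t; r'_t] = W x_t                    (one matmul, three stacked blocks)
//   f_t = sigmoid(f'_t + b_f)                    (forget gate)
//   r_t = sigmoid(r'_t + b_r)                    (reset gate)
//   c_t = f_t * c_{t-1} + (1 - f_t) * z_t        (cell state)
//   h_t = r_t * tanh(c_t) + (1 - r_t) * x_t      (output with highway connection)
//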
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_sru)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/sru.h>
#include <MmulHelper.h>
#include <helpers/PointersManager.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize]
auto b = INPUT_VARIABLE(2); // B, 1d row of biases [2*inSize], forget and reset gate biases stacked
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
auto h = OUTPUT_VARIABLE(0); // cell outputs, [bS x inSize x time]
auto c = OUTPUT_VARIABLE(1); // cell states, [bS x inSize x time]
const int rank = x->rankOf(); // = 3
const auto bS = x->sizeAt(0);
const auto inSize = x->sizeAt(1);
const auto time = x->sizeAt(2);
// input shapes validation
REQUIRE_TRUE(rank == 3, 0, "SRU operation: wrong rank of input array, expected is 3, but got %i instead !", rank);
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
// xm = x * mask
auto xm = x;
if(mask) {
xm = new NDArray(x->getShapeInfo(), true, block.launchContext());
x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xm, nullptr);
}
// time loop: runs the per-step recurrence from the file header (see helpers/sru.h)
helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c);
if(mask)
delete xm;
return Status::OK();
}
DECLARE_TYPES(sru) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru) {
auto xShapeInfo = inputShape->at(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
auto wShapeInfo = inputShape->at(1); // W, 2d tensor of weights [3*inSize x inSize]
auto bShapeInfo = inputShape->at(2); // B, 1d row of biases [2*inSize], forget and reset gate biases stacked
auto c0ShapeInfo = inputShape->at(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
const int rank = xShapeInfo[0]; // = 3
const int bS = xShapeInfo[1];
const int inSize = xShapeInfo[2];
const int time = xShapeInfo[3];
// input shapes validation
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
Nd4jLong* newShapeInfo1 = nullptr;
ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time]
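// shapeInfo is a flat buffer: rank at [0], followed by the rank dimensions, then
// strides and the trailing type/ews/order fields, which
// ShapeUtils::updateStridesAndType fills in below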
newShapeInfo1[0] = rank;
newShapeInfo1[1] = bS;
newShapeInfo1[2] = inSize;
newShapeInfo1[3] = time;
ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo));
ShapeDescriptor descriptor(newShapeInfo1);
RELEASE(newShapeInfo1, block.getWorkspace());
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return SHAPELIST(result, result);
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 2*K], forget and reset gate biases stacked
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
auto c = INPUT_VARIABLE(4); // C, [bS x K x N]
auto inGradCt = INPUT_VARIABLE(5); // [bS x K]
auto inGradH = INPUT_VARIABLE(6); // [bS x K x N]
NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
bool applyMask = false;
if (block.width() > 7) {
mask = INPUT_VARIABLE(7);
applyMask = true;
}
auto gradX = OUTPUT_VARIABLE(0); // [bS x K x N]
auto gradW = OUTPUT_VARIABLE(1); // [bS x 3K x K]
auto gradB = OUTPUT_VARIABLE(2); // [1 x 2K]
auto gradInit = OUTPUT_VARIABLE(3); // [bS x K]
const int bS = x->shapeOf()[0];
const int K = x->shapeOf()[1];
const int N = x->shapeOf()[2]; // N - number of time steps
auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext());
auto gct = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto gradTanh = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto gradCt = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto ftMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto rtMinus = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto temp1 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
auto temp2 = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
// x = x * mask
if(applyMask)
x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask
// multiplication matrix wi = matmul(w,x), U = WX
auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x N]
auto wiZ = (*wi)({0,0, 0,K, 0,0}, true); // [bS x K x N]
auto wiF = (*wi)({0,0, K,2*K, 0,0}, true); // forget gate [bS x K x N]
auto wiR = (*wi)({0,0, 2*K,3*K, 0,0}, true); // reset gate [bS x K x N]
auto bF = (*b) ({0,0, 0,K }, true); // biases for forget gate [1 x K]
auto bR = (*b) ({0,0, K,2*K}, true); // biases for reset gate [1 x K]
auto gradBF = (*gradBias)({0,0, 0,K, 0,0}, true); // [bS x K x N]
auto gradBR = (*gradBias)({0,0, K,2*K, 0,0}, true); // [bS x K x N]
auto gradUZ = (*gradU) ({0,0, 0,K, 0,0}, true ); // [bS x K x N]
auto gradUF = (*gradU) ({0,0, K,2*K, 0,0}, true ); // [bS x K x N]
auto gradUR = (*gradU) ({0,0, 2*K,3*K, 0,0}, true ); // [bS x K x N]
NDArray* ct_1 = nullptr;
std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
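// Backward recurrence per time step (it mirrors the inline comments below and
// follows from differentiating the forward equations in the file header):
//   dbR_t    = dh_t * (tanh(c_t) - x_t) * (1 - r_t) * r_t
//   dc_local = dh_t * r_t * (1 - tanh(c_t)^2)
//   dbF_t    = (dc_local + dc_next) * (c_{t-1} - z_t) * (1 - f_t) * f_t
//   dxHW_t   = dh_t * (1 - r_t)                          (highway path)
//   dz_t     = (dc_local + dc_next) * (1 - f_t)
//   dc_next  = (dc_local + dc_next) * f_t                (passed on to step t-1)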
for (int t = N-1; t >= 0; --t) {
// initialization
idx[4] = t;
idx[5] = t + 1;
auto xt = (*x)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto zt = wiZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto ft = wiF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto rt = wiR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto ct = (*c)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto inGradHt = (*inGradH)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradBRt = gradBR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradBFt = gradBF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradHXt = (*gradHX)(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradUZt = gradUZ(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradUFt = gradUF(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
auto gradURt = gradUR(idx); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
if(t != 0) {
idx[4] = t - 1;
idx[5] = t;
ct_1 = new NDArray((*c)(idx)); // previous c_{t-1}
}
else
ct_1 = c0;
///////////////// forward
// ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
ft.addRowVector(&bF, &ft);
rt.addRowVector(&bR, &rt);
ft.applyTransform(transform::Sigmoid, nullptr, nullptr);
rt.applyTransform(transform::Sigmoid, nullptr, nullptr);
// TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
ct.applyTransform(transform::Tanh, gct);
// ftMinus = 1-ft, rtMinus = 1-rt
ft.applyTransform(transform::OneMinus, ftMinus);
rt.applyTransform(transform::OneMinus, rtMinus);
///////////////// backward
// bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt)
rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt;
temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt;
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
// bF, TODO - tanh
// gradTanh = (1.0f - g_ct * g_ct);
gct->applyPairwiseTransform(pairwise::Multiply, gct, gradTanh, nullptr); // gradTanh = g_ct * g_ct
gradTanh->applyTransform(transform::OneMinus, gradTanh); // gradTanh = (1.0f - g_ct * g_ct)
// gradCt = inGradHt * rt * gradTanh
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, gradCt, nullptr); // gradCt = rt * gradTanh
inGradHt.applyPairwiseTransform(pairwise::Multiply, gradCt, gradCt, nullptr); // gradCt = inGradHt * rt * gradTanh
// gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt)
ct_1->applyPairwiseTransform(pairwise::Subtract, &zt, temp2, nullptr); // temp2 = (ct_1 - zt)
temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)
temp1->applyPairwiseTransform(pairwise::Multiply, &ft, temp1, nullptr); // temp1 = (gradCt + inGradCt)*(1-ft)*ft
temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
// x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr);
// U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh
temp1->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = inGradHt * rt * grad_tanh + inGradCt
temp1->applyPairwiseTransform(pairwise::Multiply, ftMinus, &gradUZt, nullptr); // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
gradUFt.assign(&gradBFt);
gradURt.assign(&gradBRt);
// c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
gradCt->applyPairwiseTransform(pairwise::Add, inGradCt, temp1, nullptr); // temp1 = (gradCt + inGradCt)
temp1->applyPairwiseTransform(pairwise::Multiply, &ft, inGradCt, nullptr); // inGradCt = (gradCt + inGradCt) * ft;
if(t != 0)
delete ct_1;
}
// gradInit
gradInit->assign(inGradCt);
// gradX
auto weightsT = w->transpose(); // [K x 3K]
MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N]
gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x
if(applyMask)
gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// gradB
auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K]
gradB->assign(temp3);
// gradW [bS x 3K x K]
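// per batch item this realizes dW = dU * X^T: permute x to [bS x N x K] and
// batched-mmul it with gradU [bS x 3K x N]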
x->permutei({0, 2, 1}); // [bS x N x K]
MmulHelper::mmul(gradU, x, gradW, 1., 0.); // [bS x 3K x K]
delete gct; delete gradU; delete gradHX;
delete temp1; delete temp2; delete temp3; delete gradCt; delete wi;
delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;
return Status::OK();
}
DECLARE_TYPES(sru_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru_bp) {
auto inShape = inputShape->at(0); // [bS x inSize x time]
auto bS = inShape[1];
auto inSize = inShape[2];
auto time = inShape[3];
const char order = shape::order(inShape);
ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 4*inSize], forget and reset gate biases for both directions
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize]
auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize]
// input shapes validation
const int rank = 3;
const Nd4jLong bS = x->sizeAt(1);
const Nd4jLong inSize = x->sizeAt(2) / 2;
REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
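// the last input dimension packs the forward and backward directions side by side,
// hence the factor of 2 in [time x bS x 2*inSize]; each direction runs the same
// cell equations as the unidirectional sru above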
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
return Status::OK();
}
DECLARE_TYPES(sru_bi) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(sru_bi) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3);
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
const int rank = xShapeInfo[0]; // = 3
const Nd4jLong time = xShapeInfo[1];
const Nd4jLong bS = xShapeInfo[2];
const Nd4jLong inSize = xShapeInfo[3] / 2;
// input shapes validation
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] <= rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=%i, but got %i instead !", rank-1, bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
char order = shape::order(xShapeInfo);
ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
return SHAPELIST(result, result);
}
DECLARE_TYPES(sru_bi_bp) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 4*inSize], forget and reset gate biases for both directions
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize]
auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize]
auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize]
NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
// input shapes validation
const int rank = 3;
const Nd4jLong time = x->sizeAt(0);
const Nd4jLong bS = x->sizeAt(1);
const Nd4jLong inSize = x->sizeAt(2) / 2;
REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
REQUIRE_TRUE(inGradHt->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHt->rankOf());
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ct);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI_BP operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize]
auto gradC0 = OUTPUT_VARIABLE(3); // [bS x 2*inSize]
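// the helper presumably runs the sru_bp-style recurrence once per direction over
// the packed 2*inSize axis (see helpers/sru.h for the actual kernels)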
helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, mask, gradI, gradW, gradB, gradC0);
return Status::OK();
}
DECLARE_SHAPE_FN(sru_bi_bp) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3);
auto ctShapeInfo = inputShape->at(4);
auto inGradC0ShapeInfo = inputShape->at(5);
auto inGradHtShapeInfo = inputShape->at(6);
Nd4jLong* maskShapeInfo = block.width() > 7 ? inputShape->at(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
// input shapes validation
const int rank = xShapeInfo[0];
const Nd4jLong time = xShapeInfo[1];
const Nd4jLong bS = xShapeInfo[2];
const Nd4jLong inSize = xShapeInfo[3] / 2;
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=%i, but got %i instead !", rank-1, bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo[0]);
REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
REQUIRE_TRUE(inGradHtShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHtShapeInfo[0]);
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
const std::string inGradC0Shape = ShapeUtils::shapeAsString(inGradC0ShapeInfo);
const std::string inGradC0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string inGradHtShape = ShapeUtils::shapeAsString(inGradHtShapeInfo);
const std::string inGradHtCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI_BP operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
REQUIRE_TRUE(inGradC0Shape == inGradC0CorrectShape, 0, "SRU_BI_BP operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", inGradC0CorrectShape.c_str(), inGradC0Shape.c_str());
REQUIRE_TRUE(inGradHtShape == inGradHtCorrectShape, 0, "SRU_BI_BP operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", inGradHtCorrectShape.c_str(), inGradHtShape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
const char order = shape::order(xShapeInfo);
ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {1, 4 * inSize});
ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
}
}
}
#endif
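//////////////////////////////////////////////////////////////////////////
// Minimal usage sketch for the sru op (kept commented out like the legacy code
// below; it is an illustration only, and the exact execute()/ResultSet
// signatures vary across library versions):
//
// void sruUsageSketch() {
//     const int bS = 2, inSize = 4, time = 5;
//     auto x  = NDArrayFactory::create<float>('c', {bS, inSize, time});  // input sequence
//     auto w  = NDArrayFactory::create<float>('c', {3*inSize, inSize});  // stacked z/f/r weights
//     auto b  = NDArrayFactory::create<float>('c', {2*inSize});          // forget + reset biases
//     auto c0 = NDArrayFactory::create<float>('c', {bS, inSize});        // initial cell state
//     nd4j::ops::sru op;
//     auto results = op.execute({&x, &w, &b, &c0}, {}, {});              // outputs: h, c
//     auto h = results->at(0);                                           // [bS x inSize x time]
//     auto c = results->at(1);                                           // [bS x inSize x time]
//     delete results;
// }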
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases [1 x 2K], forget and reset gate biases stacked
* 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N]
*/
// #if NOT_EXCLUDED(OP_sru_old)
// DECLARE_CUSTOM_OP(sru_old, 5, 2, false, 0, 0);
// #endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases [1 x 2K], forget and reset gate biases stacked
* 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N]
*/
// #if NOT_EXCLUDED(OP_sru_logic)
// DECLARE_CUSTOM_OP(sru_logic, 5, 2, false, 0, 0);
// #endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
*
* Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K]
* 2: row of biases [1 x 2K], forget and reset gate biases stacked
* 3: 2d tensor of previous cell state [bS x K]
* 4: 3d tensor of cell state [bS x K x N]
* 5: 2d tensor of cell state gradients [bS x K]
* 6: 3d tensor of state output gradients [bS x K x N]
* 7: optional, 2d tensor of dropout mask [bS x K]
*
* Output arrays:
* 0: 3d tensor of input gradients [bS x K x N]
* 1: 3d tensor of weights gradients [bS x 3K x K]
* 2: 2d, row of biases gradients [1 x 2K]
* 3: 2d, tensor of state gradients [bS x K]
*/
// #if NOT_EXCLUDED(OP_sru_logic)
// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0);
// #endif
// return 2d array evaluated through last-dimension interval [t1, t2)
// static NDArray* timestep(const NDArray* arr, const int t1, const int t2) {
// NDArray* result = new NDArray((*arr)({0,0, 0,0, t1,t2}, true));
// result->reshapei(result->ordering(), {arr->shapeOf()[0], arr->shapeOf()[1]} );
// return result;
// }
/////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
// auto bias = INPUT_VARIABLE(2); // B, row of biases [1 x 2*K]
// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
// bool applyMask = false;
// if (block.width() > 4) {
// mask = INPUT_VARIABLE(4);
// applyMask = true;
// }
// auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N]
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N]
// const int bS = input->shapeOf()[0]; // bS - batch size
// const int K = input->shapeOf()[1]; // K - number of features
// const int N = input->shapeOf()[2]; // N - number of time steps
// const auto wi = mmul(*weights, *input); // U [bS x 3K x N]
// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K]
// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K]
// NDArray xt(input->dataType(), block.launchContext());
// NDArray zt(input->dataType(), block.launchContext());
// NDArray ft(input->dataType(), block.launchContext());
// NDArray rt(input->dataType(), block.launchContext());
// NDArray ht(input->dataType(), block.launchContext());
// NDArray ct = *init;
// NDArray gct(state->ordering(), {bS, K}, input->dataType(), block.launchContext());
// NDArray xmt = *input;
// // input = input * mask
// if(applyMask)
// xmt.applyBroadcast(broadcast::Multiply, {0, 1}, mask, &xmt, nullptr);
// for (int t = 0; t < N; ++t) {
// xt = xmt({0,0, 0,0, t,t+1}); xt.reshapei(xt.ordering(), {bS, K}); // [bS x K x N] -> [bS x K x 1] -> [bS x K]
// zt = wi({0,0, 0, K, t,t+1}); zt.reshapei(zt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// ft = wi({0,0, K, 2*K, t,t+1}); ft.reshapei(ft.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// rt = wi({0,0, 2*K,3*K, t,t+1}); rt.reshapei(rt.ordering(), {bS, K}); // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
// ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR);
// ct = ft * (ct - zt) + zt;
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct);
// ht = rt * (gct - xt) + xt;
// // save results
// (*output)({0,0, 0,0, t,t+1}, true).assign(ht);
// (*state)({0,0, 0,0, t,t+1}, true).assign(ct);
// }
// return Status::OK();
// }
// DECLARE_TYPES(sru_logic) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_logic) {
// auto inShape = inputShape->at(0); // [bS x K x N]
// int rank = inShape[0]; // = 3
// int size = rank*2 + 4;
// int bS = inShape[1];
// int K = inShape[2];
// int N = inShape[3];
// char order = (char)(inShape[size-1]);
// Nd4jLong* newShapeInfo1 = nullptr;
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
// newShapeInfo1[0] = rank;
// newShapeInfo1[1] = bS;
// newShapeInfo1[2] = K;
// newShapeInfo1[3] = N;
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
// auto result = CONSTANT(newShapeInfo1);
// return SHAPELIST(result, result);
// }
// //////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_old, 5, 2, false, 0, 0) {
// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x inSize]
// auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 2*inSize]
// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
// bool applyMask = false;
// if (block.width() > 4) {
// mask = INPUT_VARIABLE(4);
// applyMask = true;
// }
// auto h = OUTPUT_VARIABLE(0); // h_t, [bS x inSize x time]
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x inSize x time]
// const int bS = x->shapeOf()[0]; // bS - batch size
// const int inSize = x->shapeOf()[1]; // inSize - number of features
// const int time = x->shapeOf()[2]; // time - number of time steps
// // multiplication matrix = matmul(w,x)
// auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.); // U [bS x 3K x time]
// auto wiZ = (*wi)({0,0, 0,inSize, 0,0}, true); // [bS x inSize x time]
// auto wiF = (*wi)({0,0, inSize,2*inSize, 0,0}, true); // forget gate [bS x inSize x time]
// auto wiR = (*wi)({0,0, 2*inSize,3*inSize, 0,0}, true); // reset gate [bS x inSize x time]
// auto bF = (*b) ({0,0, 0,inSize }, true); // biases for forget gate [1 x inSize]
// auto bR = (*b) ({0,0, inSize,2*inSize}, true); // biases for reset gate [1 x inSize]
// NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), *ct(nullptr), *ht(nullptr);
// auto ct_1 = c0->dup(c0->ordering());
// auto gct = NDArrayFactory::create_(state->ordering(), {bS, inSize}, state->dataType(), state->getContext());
// auto xmt = x->dup(x->ordering());
// // x = x * mask
// if(applyMask)
// xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr); // apply mask
// for (int t = 0; t < time; ++t) {
// xt = timestep(xmt, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// zt = timestep(&wiZ, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ft = timestep(&wiF, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// rt = timestep(&wiR, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ct = timestep(state, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// ht = timestep(h, t, t+1); // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
// ft->addRowVector(&bF, ft);
// rt->addRowVector(&bR, rt);
// ft->applyTransform(transform::Sigmoid, ft, nullptr);
// rt->applyTransform(transform::Sigmoid, rt, nullptr);
// // ct = ft * c_t-1 + (1 - ft) * zt,
// ft->applyPairwiseTransform(pairwise::Multiply, ct_1, ct, nullptr);
// ft->applyTransform(transform::OneMinus, ft);
// ft->applyPairwiseTransform(pairwise::Multiply, *zt, nullptr);
// ct->applyPairwiseTransform(pairwise::Add, *ft, nullptr);
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct->applyTransform(transform::Tanh, gct);
// // ht = rt * gct + (1 - rt) * xt
// rt->applyPairwiseTransform(pairwise::Multiply, gct, ht, nullptr);
// rt->applyTransform(transform::OneMinus, rt);
// rt->applyPairwiseTransform(pairwise::Multiply, *xt, nullptr);
// ht->applyPairwiseTransform(pairwise::Add, *rt, nullptr);
// delete xt; delete zt; delete ft; delete rt; delete ht; delete ct_1;
// ct_1 = ct;
// }
// delete wi; delete ct_1; delete gct; delete xmt;
// return Status::OK();
// }
// DECLARE_TYPES(sru_old) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_old) {
// auto inShape = inputShape->at(0); // [bS x inSize x time]
// int rank = inShape[0]; // = 3
// int size = rank*2 + 4;
// auto bS = inShape[1];
// auto inSize = inShape[2];
// int time = inShape[3];
// char order = (char)(inShape[size-1]);
// Nd4jLong *newShapeInfo1 = nullptr;
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
// newShapeInfo1[0] = rank;
// newShapeInfo1[1] = bS;
// newShapeInfo1[2] = inSize;
// newShapeInfo1[3] = time;
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
// auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newShapeInfo1));
// RELEASE(newShapeInfo1, block.getWorkspace());
// return SHAPELIST(result, result);
// }
// static NDArray sigmoid_(const NDArray& arr) {
// NDArray result(arr.getShapeInfo(), false, arr.getContext());
// (const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid, &result);
// return result;
// }
//////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_bp_logic, 8, 4, true, 0, 0) {
// auto x = INPUT_VARIABLE(0); // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
// auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [3*inSize x inSize]
// auto b = INPUT_VARIABLE(2); // B, row of biases [1 x 2*inSize]
// auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
// auto c = INPUT_VARIABLE(4); // C, [bS x inSize x time]
// auto inGradCt = INPUT_VARIABLE(5); // [bS x inSize]
// auto inGradH = INPUT_VARIABLE(6); // [bS x inSize x time]
// auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
// auto gradX = OUTPUT_VARIABLE(0); // [bS x inSize x time]
// auto gradW = OUTPUT_VARIABLE(1); // [bS x 3*inSize x inSize]
// auto gradB = OUTPUT_VARIABLE(2); // [2*inSize]
// auto gradInit = OUTPUT_VARIABLE(3); // [bS x inSize]
// // input shapes validation
// const int rank = 3;
// REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
// REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
// REQUIRE_TRUE(b->rankOf() <= 2, 0, "SRU_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
// REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
// REQUIRE_TRUE(c->rankOf() == rank, 0, "SRU_BP operation: wrong rank of cell states array, expected is %i, but got %i instead !", rank, c->rankOf());
// REQUIRE_TRUE(inGradCt->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, expected is %i, but got %i instead !", rank-1, inGradCt->rankOf());
// REQUIRE_TRUE(inGradH->rankOf() == rank, 0, "SRU_BP operation: wrong rank of array of cell outputs gradients, expected is %i, but got %i instead !", rank, inGradH->rankOf());
// if(mask)
// REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
// const int bS = x->shapeOf()[0];
// const int inSize = x->shapeOf()[1];
// const int time = x->shapeOf()[2]; // time - number of time steps
// const std::string wShape = ShapeUtils::shapeAsString(w);
// const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
// // const std::string bShape = ShapeUtils::shapeAsString(b);
// // const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
// const std::string c0Shape = ShapeUtils::shapeAsString(c0);
// const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
// const std::string cShape = ShapeUtils::shapeAsString(c);
// const std::string cCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
// const std::string inGradCtShape = ShapeUtils::shapeAsString(inGradCt);
// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH);
// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
// REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP operation: wrong shape of cell states array, expected is %s, but got %s instead !", cCorrectShape.c_str(), cShape.c_str());
// REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell state gradient, expected is %s, but got %s instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str());
// REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell outputs gradients, expected is %s, but got %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str());
// if(mask) {
// const std::string maskShape = ShapeUtils::shapeAsString(mask);
// REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
// }
// const auto bF = (*b)({0,0, 0, inSize}); // biases for forget gate [1 x inSize]
// const auto bR = (*b)({0,0, inSize,2*inSize}); // biases for reset gate [1 x inSize]
// NDArray gradBias(x->ordering(), {bS, 2*inSize, time}, x->dataType(), block.launchContext());
// NDArray gradU (x->ordering(), {bS, 3*inSize, time}, x->dataType(), block.launchContext());
// NDArray gradHX (x->ordering(), {bS, inSize, time}, x->dataType(), block.launchContext());
// NDArray gct (c->ordering(), {bS, inSize}, x->dataType(), block.launchContext());
// // x = x * mask
// if(mask)
// x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr); // apply mask
// // multiplication matrix wi = matmul(w,x), U = WX
// const auto wi = mmul(*w, *x); // U [bS x 3K x time]
// for (int t = time-1; t >=0 ; --t) {
// // initialization
// auto xt = (*x)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto zt = wi({0,0, 0, inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize]
// auto ft = wi({0,0, inSize, 2*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize]
// auto rt = wi({0,0, 2*inSize,3*inSize, t,t+1}); // [bS x 3K x time] -> [bS x inSize]
// auto ct = (*c)({0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1}
// ///////////////// forward
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
// ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR);
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct);
// ///////////////// backward
// // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
// // ftMinus = -ft + (T)1.;
// NDArray ftMinus = 1. - ft;
// NDArray rtMinus = 1. - rt;
// NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
// // bF, TODO - tanh
// NDArray gradTanh = 1. - gct * gct;
// NDArray gradCt = inGradHt * rt * gradTanh;
// NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
// // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
// NDArray gradHXt = inGradHt * rtMinus;
// // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
// NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * ftMinus;
// // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
// *inGradCt = (gradCt + *inGradCt) * ft;
// // save results
// gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
// gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
// gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
// gradU({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBFt);
// gradU({0,0, 2*inSize, 3*inSize, t,t+1}, true).assign(gradBRt);
// gradHX({0,0, 0,0, t,t+1}, true).assign(gradHXt);
// }
// // gradInit
// gradInit->assign(inGradCt);
// // gradX
// w->transposei(); // [inSize x 3K]
// gradX->assign( mmul(*w, gradU) + gradHX);
// if(mask)
// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// // gradB
// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K]
// // gradW [bS x 3K x inSize]
// x->permutei({0, 2, 1}); // [bS x time x inSize]
// gradW->assign( mmul(gradU, *x) );
// return Status::OK();
// }
// DECLARE_TYPES(sru_bp_logic) {
// getOpDescriptor()
// ->setAllowedInputTypes(nd4j::DataType::ANY)
// ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_bp_logic) {
// auto inShape = inputShape->at(0); // [bS x inSize x time]
// auto bS = inShape[1];
// auto inSize = inShape[2];
// auto time = inShape[3];
// char order = shape::order(inShape);
// ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
// ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
// ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
// }