autox=INPUT_VARIABLE(0);// input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter
autoWxFW=INPUT_VARIABLE(1);// input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
autoWhFW=INPUT_VARIABLE(2);// hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW]
autobFW=INPUT_VARIABLE(3);// biases for forward RNN, [2*numUnitsFW]
autoWxBW=INPUT_VARIABLE(4);// input-to-hidden weights for backward RNN, [inSize x numUnitsBW]
autoWhBW=INPUT_VARIABLE(5);// hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW]
autobBW=INPUT_VARIABLE(6);// biases for backward RNN, [2*v]
NDArray*h0FW=nullptr;// initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW]
NDArray*h0BW=nullptr;// initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW]
NDArray*maxTimeStep=nullptr;// vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
constinttimeMajor=block.getIArguments()->size()>0?INT_ARG(0):0;// if non zero then [time, bS, ...], else [bS, time, ...]
switch(block.width()){
case8:
maxTimeStep=INPUT_VARIABLE(7);
break;
case9:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
break;
case10:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
maxTimeStep=INPUT_VARIABLE(9);
break;
}
autohFW=OUTPUT_VARIABLE(0);// cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x time x numUnitsFW], shape depends on timeMajor parameter
autohBW=OUTPUT_VARIABLE(1);// cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x time x numUnitsBW], shape depends on timeMajor parameter
autohFWFinal=OUTPUT_VARIABLE(2);// final cell out for forward RNN [bS x numUnitsFW]
autohBWFinal=OUTPUT_VARIABLE(3);// final cell out for backward RNN [bS x numUnitsBF]
REQUIRE_TRUE(x->rankOf()==3,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !",x->rankOf());
REQUIRE_TRUE(WxFW->rankOf()==2,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !",WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf()==2,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !",WxBW->rankOf());
constintinRank=x->rankOf();
constinttime=timeMajor?x->sizeAt(0):x->sizeAt(1);
constintbS=timeMajor?x->sizeAt(1):x->sizeAt(0);
constintnumUnitsFW=WxFW->sizeAt(1);
constintnumUnitsBW=WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW)==ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW)==ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW)==ShapeUtils::shapeAsString({2*numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(),ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW)==ShapeUtils::shapeAsString({2*numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(),ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW)==ShapeUtils::shapeAsString({bS,numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW)==ShapeUtils::shapeAsString({bS,numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep)==ShapeUtils::shapeAsString({bS}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !",bS,ShapeUtils::shapeAsString(maxTimeStep).c_str());
autox=INPUT_VARIABLE(0);// input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter
autoWxFW=INPUT_VARIABLE(1);// input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
autoWhFW=INPUT_VARIABLE(2);// hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW]
autobFW=INPUT_VARIABLE(3);// biases for forward RNN, [2*numUnitsFW]
autoWxBW=INPUT_VARIABLE(4);// input-to-hidden weights for backward RNN, [inSize x numUnitsBW]
autoWhBW=INPUT_VARIABLE(5);// hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW]
autobBW=INPUT_VARIABLE(6);// biases for backward RNN, [2*numUnitsBW]
NDArray*h0FW=nullptr;// initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW]
NDArray*h0BW=nullptr;// initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW]
NDArray*maxTimeStep=nullptr;// vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
constinttimeMajor=block.getIArguments()->size()>0?INT_ARG(0):0;// if true then [time, bS, ...], else [bS, time, ...]
switch(block.width()){
case8:
maxTimeStep=INPUT_VARIABLE(7);
break;
case9:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
break;
case10:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
maxTimeStep=INPUT_VARIABLE(9);
break;
}
REQUIRE_TRUE(x->rankOf()==3,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !",x->rankOf());
REQUIRE_TRUE(WxFW->rankOf()==2,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !",WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf()==2,0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !",WxBW->rankOf());
constintinRank=x->rankOf();
constinttime=timeMajor?x->sizeAt(0):x->sizeAt(1);
constintbS=timeMajor?x->sizeAt(1):x->sizeAt(0);
constintnumUnitsFW=WxFW->sizeAt(1);
constintnumUnitsBW=WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW)==ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW)==ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW)==ShapeUtils::shapeAsString({2*numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(),ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW)==ShapeUtils::shapeAsString({2*numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(),ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW)==ShapeUtils::shapeAsString({bS,numUnitsFW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW)==ShapeUtils::shapeAsString({bS,numUnitsBW}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep)==ShapeUtils::shapeAsString({bS}),0,"DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !",bS,ShapeUtils::shapeAsString(maxTimeStep).c_str());