autox=INPUT_VARIABLE(0);// input [time x bS x inSize]
autoWxFW=INPUT_VARIABLE(1);// input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
autoWhFW=INPUT_VARIABLE(2);// hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW]
autobFW=INPUT_VARIABLE(3);// biases for forward RNN, [2*numUnitsFW]
autoWxBW=INPUT_VARIABLE(4);// input-to-hidden weights for backward RNN, [inSize x numUnitsBW]
autoWhBW=INPUT_VARIABLE(5);// hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW]
autobBW=INPUT_VARIABLE(6);// biases for backward RNN, [2*v]
NDArray*h0FW=nullptr;// initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW]
NDArray*h0BW=nullptr;// initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW]
NDArray*maxTimeStep=nullptr;// vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
switch(block.width()){
case8:
maxTimeStep=INPUT_VARIABLE(7);
break;
case9:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
break;
case10:
h0FW=INPUT_VARIABLE(7);
h0BW=INPUT_VARIABLE(8);
maxTimeStep=INPUT_VARIABLE(9);
break;
}
autoh=OUTPUT_VARIABLE(0);// cell outputs [time x bS x (numUnitsFW + numUnitsBW)], that is per each time step
autohFWFinal=OUTPUT_VARIABLE(1);// final cell out for forward RNN [bS x numUnitsFW]
autohBWFinal=OUTPUT_VARIABLE(2);// final cell out for backward RNN [bS x numUnitsBF]
REQUIRE_TRUE(x->rankOf()==3,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !",x->rankOf());
REQUIRE_TRUE(WxFW->rankOf()==2,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !",WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf()==2,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !",WxBW->rankOf());
constNd4jLonginRank=x->rankOf();
constNd4jLongtime=x->sizeAt(0);
constNd4jLongbS=x->sizeAt(1);
constNd4jLongnumUnitsFW=WxFW->sizeAt(1);
constNd4jLongnumUnitsBW=WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW)==ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW)==ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW)==ShapeUtils::shapeAsString({2*numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(),ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW)==ShapeUtils::shapeAsString({2*numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(),ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW)==ShapeUtils::shapeAsString({bS,numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW)==ShapeUtils::shapeAsString({bS,numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep)==ShapeUtils::shapeAsString({bS}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !",bS,ShapeUtils::shapeAsString(maxTimeStep).c_str());
autoxShapeInfo=inputShape->at(0);// input [time x bS x inSize]
autoWxFWShapeInfo=inputShape->at(1);// input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
autoWhFWShapeInfo=inputShape->at(2);// hidden-to-hidden weights for forward RNN, [numUnitsFW x numUnitsFW]
autobFWShapeInfo=inputShape->at(3);// biases for forward RNN, [2*numUnitsFW]
autoWxBWShapeInfo=inputShape->at(4);// input-to-hidden weights for backward RNN, [inSize x numUnitsBW]
autoWhBWShapeInfo=inputShape->at(5);// hidden-to-hidden weights for backward RNN, [numUnitsBW x numUnitsBW]
autobBWShapeInfo=inputShape->at(6);// biases for backward RNN, [2*numUnitsBW]
Nd4jLong*h0FWShapeInfo=nullptr;// initial cell output for forward RNN (at time step = 0) [bS x numUnitsFW]
Nd4jLong*h0BWShapeInfo=nullptr;// initial cell output for backward RNN (at time step = 0) [bS x numUnitsBW]
Nd4jLong*maxTimeStepShapeInfo=nullptr;// vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
switch(block.width()){
case8:
maxTimeStepShapeInfo=inputShape->at(7);
break;
case9:
h0FWShapeInfo=inputShape->at(7);
h0BWShapeInfo=inputShape->at(8);
break;
case10:
h0FWShapeInfo=inputShape->at(7);
h0BWShapeInfo=inputShape->at(8);
maxTimeStepShapeInfo=inputShape->at(9);
break;
}
REQUIRE_TRUE(xShapeInfo[0]==3,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !",xShapeInfo[0]);
REQUIRE_TRUE(WxFWShapeInfo[0]==2,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !",WxFWShapeInfo[0]);
REQUIRE_TRUE(WxBWShapeInfo[0]==2,0,"STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !",WxBWShapeInfo[0]);
constintinRank=xShapeInfo[0];
constinttime=xShapeInfo[1];
constintbS=xShapeInfo[2];
constintnumUnitsFW=WxFWShapeInfo[2];
constintnumUnitsBW=WxBWShapeInfo[2];
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFWShapeInfo)==ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsFW,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(WhFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBWShapeInfo)==ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({numUnitsBW,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(WhBWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFWShapeInfo)==ShapeUtils::shapeAsString({2*numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(),ShapeUtils::shapeAsString(bFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBWShapeInfo)==ShapeUtils::shapeAsString({2*numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !",ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(),ShapeUtils::shapeAsString(bBWShapeInfo).c_str());
if(h0FWShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FWShapeInfo)==ShapeUtils::shapeAsString({bS,numUnitsFW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsFW}).c_str(),ShapeUtils::shapeAsString(h0FWShapeInfo).c_str());
if(h0BWShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BWShapeInfo)==ShapeUtils::shapeAsString({bS,numUnitsBW}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !",ShapeUtils::shapeAsString({bS,numUnitsBW}).c_str(),ShapeUtils::shapeAsString(h0BWShapeInfo).c_str());
if(maxTimeStepShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo)==ShapeUtils::shapeAsString({bS}),0,"STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !",bS,ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());