REQUIRE_TRUE(x->rankOf()==4,0,"CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !",x->rankOf());
intbS=x->sizeAt(0);// batch size
intiH,iW,iD;// input height, input width, input depth(number of channels)
if(dataFormat){
iD=x->sizeAt(1);
iH=x->sizeAt(2);
iW=x->sizeAt(3);
}
else{
iD=x->sizeAt(3);
iH=x->sizeAt(1);
iW=x->sizeAt(2);
}
autoxCast=x->cast(sd::DataType::FLOAT32);
REQUIRE_TRUE(scale->rankOf()==1&&scale->sizeAt(0)==iD,0,"CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead",iD,ShapeUtils::shapeAsString(scale).c_str());
REQUIRE_TRUE(offset->rankOf()==1&&offset->sizeAt(0)==iD,0,"CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead",iD,ShapeUtils::shapeAsString(offset).c_str());
NDArray*mean(nullptr),*variance(nullptr);
// When isTraining is false, mean and variance are taken from INPUT_VARIABLE(3) and
// INPUT_VARIABLE(4) (presumably op inputs #3 and #4 - confirm against the macro), and each
// is required to be a rank-1 array whose single dimension equals iD.
if(!isTraining){
mean=INPUT_VARIABLE(3);
variance=INPUT_VARIABLE(4);
REQUIRE_TRUE(mean->rankOf()==1&&mean->sizeAt(0)==iD,0,"CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead",iD,ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->rankOf()==1&&variance->sizeAt(0)==iD,0,"CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead",iD,ShapeUtils::shapeAsString(variance).c_str());
}
else{
// isTraining is true here: the visible part of this branch does not assign mean or variance.
// The input-count check below is disabled (commented out) and kept only for reference.
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
// NOTE(review): scaleShapeInfo is not declared in the visible code. The check requires
// element [0] to be 1 and element [1] to equal iD - presumably [0] holds the rank and [1]
// the first dimension of the scale array's shape descriptor; confirm against its definition.
REQUIRE_TRUE(scaleShapeInfo[0]==1&&scaleShapeInfo[1]==iD,0,"CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead",iD,ShapeUtils::shapeAsString(scaleShapeInfo).c_str());