REQUIRE_TRUE(numOfAxes<=inRank,0,"BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !",numOfAxes,inRank);
// evaluate expected shape for mean, variance, gamma and beta; these arrays must all have identical shapes
// for example, if input shape is {2,3,4,5,6} and axes = {1,3}, the expected shape is {1,3,1,5,1}; if axes = {3}, the expected shape is {5} (see the worked example after these checks)
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of mean array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
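// Worked example (illustrative values, not part of the op): with input shape {2,3,4,5,6}
// and axes = {1,3}, expShape is {1,3,1,5,1}; a mean array of shape {1,3,1,5,1} passes the
// checks above, while shapes such as {3,5} or {1,3,1,5} trigger the REQUIRE_TRUE failure.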
axes.push_back(inRank - 1);    // default dimension to reduce along is the last one
const int numOfAxes = axes.size();
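// (Illustration: for a rank-4 input with no axes supplied, axes becomes {3}, i.e.
// normalization runs over the last dimension, the channel axis in NHWC layout.)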
REQUIRE_TRUE(numOfAxes<=inRank,0,"BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !",numOfAxes,inRank);
// evaluate expected shape for mean, variance, gamma and beta (the latter two when present); these arrays must all have identical shapes
// for example, if input shape is {2,3,4,5,6} and axes = {1,3}, the expected shape is {1,3,1,5,1}; if axes = {3}, the expected shape is {5}
std::vector<Nd4jLong> expShape;
if (numOfAxes == 1)
    expShape.push_back(input->sizeAt(axes[0]));
else {    // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
    expShape = std::vector<Nd4jLong>(inRank, 1);
    for (uint i = 0; i < numOfAxes; ++i)
        expShape[axes[i]] = input->sizeAt(axes[i]);
}
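// Trace (illustrative values): for input shape {2,3,4,5,6} and axes = {1,3}, expShape
// starts as {1,1,1,1,1}; the loop sets expShape[1] = 3 and expShape[3] = 5, giving
// {1,3,1,5,1}. For axes = {3} the single-axis branch yields the 1-D shape {5}.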
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
if (gamma)
    REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
if (beta)
    REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected %s, but got %s instead!", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
// types of all input arrays should be the same (except dLdO)
for (int i = 1; i < block.width() - 1; ++i)
    if (i != 3)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should all be the same!");
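// Note: the loop bound block.width() - 1 excludes the last input and the i != 3 test
// skips index 3; both exclusions serve to keep dLdO out of the type check (which input
// slot dLdO occupies depends on whether the optional gamma/beta arrays are supplied).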