Shyrma bn mkl bp (#14)

* - write code for new batchnorm backprop

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing batchnorm backprop

Signed-off-by: Yurii <iuriish@yahoo.com>

* - write code for batchnorm backprop based on mkl dnn api

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in batchnorm_bp mkl dnn

Signed-off-by: Yurii <iuriish@yahoo.com>

* - made corrections required by reviewer

Signed-off-by: Yurii <iuriish@yahoo.com>

* - change name in java wrapper for batchnorm op

Signed-off-by: Yurii <iuriish@yahoo.com>
Branch: master
Author: Yurii Shyrma, 2019-10-26 14:14:21 +03:00 (committed by GitHub)
Parent: d333d29099
Commit: 029a69a835
16 changed files with 1280 additions and 699 deletions


@ -60,6 +60,7 @@ namespace nd4j {
Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape);
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace);
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);


@ -99,6 +99,10 @@ namespace nd4j {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
}
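The new overload takes the data type from its first argument and the order/rank/shape from an existing shapeInfo. A minimal usage sketch, mirroring how the batchnorm_bp shape function later in this commit uses it (variable names are illustrative):

    // build a shapeInfo with a floating-point type but the same order/rank/shape as inShapeInfo
    DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
    Nd4jLong* dLdIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo);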
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();


@ -102,6 +102,10 @@ namespace nd4j {
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) {
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
}
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();


@ -29,84 +29,8 @@ namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray *gamma = nullptr;
NDArray *beta = nullptr;
auto output = OUTPUT_VARIABLE(0);
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
// FIXME: double?
const double epsilon = T_ARG(0);
if(applyScale)
gamma = INPUT_VARIABLE(3);
if(applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
std::vector<const NDArray*> inArrs(block.width());
for(int i = 0; i < block.width(); ++i)
inArrs[i] = INPUT_VARIABLE(i);
// check whether all input shapes are mutually broadcastable
Nd4jLong* outShapeInfo = nullptr;
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !");
// normalized output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt);
if(applyScale)
sigmaInvGam *= *gamma;
NDArray inputMinusMean;
if(!input->isSameShape(output) && !mean->isSameShape(output)) {
auto inputTiled = NDArray(output, false, block.launchContext());
input->tile(inputTiled);
inputMinusMean = inputTiled - *mean;
}
else
inputMinusMean = *input - *mean;
if (applyOffset)
output->assign(inputMinusMean * sigmaInvGam + *beta);
else
output->assign(inputMinusMean * sigmaInvGam);
return Status::OK();
}
DECLARE_TYPES(batchnorm) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(batchnorm) {
std::vector<const NDArray*> inArrs(block.width());
auto in = inputShape->at(0);
for(int i = 0; i < block.width(); ++i)
inArrs[i] = INPUT_VARIABLE(i);
// check whether all input shapes are mutually broadcastable
Nd4jLong* outShapeInfo = nullptr;
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !");
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, DataTypeUtils::pickFloatingType(ArrayOptions::dataType(in))));
return SHAPELIST(result);
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
@ -123,7 +47,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
if(applyScale)
gamma = INPUT_VARIABLE(3);
if(applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
beta = INPUT_VARIABLE(3 + (int)applyScale);
const int numOfIntArgs = block.getIArguments()->size();
const int inRank = input->rankOf();
@ -137,30 +61,31 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
const int numOfAxes = axes.size();
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_NEW op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
// get, for example, something like {1, inDim1, 1, inDim3, 1} if axes = {1, 3}
std::vector<Nd4jLong> expShapeWithUnities(inRank, 1);
for(int i = 0; i < numOfAxes; ++i)
expShapeWithUnities[axes[i]] = input->sizeAt(axes[i]);
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
std::vector<Nd4jLong> expShape = numOfAxes == 1 ? std::vector<Nd4jLong>(1, input->sizeAt(axes[0])) : expShapeWithUnities;
std::string expShapeStr = ShapeUtils::shapeAsString(expShape);
std::vector<Nd4jLong> expShape;
if(numOfAxes == 1)
expShape.push_back(input->sizeAt(axes[0]));
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
expShape = std::vector<Nd4jLong>(inRank, 1);
for(uint i = 0; i < numOfAxes; ++i)
expShape[axes[i]] = input->sizeAt(axes[i]);
}
REQUIRE_TRUE(ShapeUtils::shapeAsString(mean) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of mean array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(variance) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of variance array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(variance).c_str());
REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
if(gamma)
REQUIRE_TRUE(ShapeUtils::shapeAsString(gamma) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of gamma array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(gamma).c_str());
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
if(beta)
REQUIRE_TRUE(ShapeUtils::shapeAsString(beta) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of beta array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(beta).c_str());
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
// types of all input arrays should be the same
for(int i = 1; i < block.width(); ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !");
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !");
nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0);
nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
@ -168,11 +93,11 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
return Status::OK();
}
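For reference, the formula in the comment above can be written as a minimal standalone sketch (plain float buffers, channel-last layout, per-channel statistics already given; names are illustrative and not part of the library):

    #include <cmath>
    #include <cstddef>

    // output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta, applied per channel
    static void batchnormForwardSketch(const float* input, const float* mean, const float* variance,
                                       const float* gamma, const float* beta, float* output,
                                       std::size_t numSamples, std::size_t numChannels, float epsilon) {
        for (std::size_t i = 0; i < numSamples; ++i) {
            for (std::size_t c = 0; c < numChannels; ++c) {
                const std::size_t idx = i * numChannels + c;
                const float xHat = (input[idx] - mean[c]) / std::sqrt(variance[c] + epsilon);
                output[idx] = gamma[c] * xHat + beta[c];
            }
        }
    }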
DECLARE_TYPES(batchnorm_new) {
DECLARE_TYPES(batchnorm) {
getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true);
}
DECLARE_SHAPE_FN(batchnorm_new) {
DECLARE_SHAPE_FN(batchnorm) {
auto inShapeInfo = inputShape->at(0);
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
@ -184,98 +109,129 @@ DECLARE_SHAPE_FN(batchnorm_new) {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray* input = INPUT_VARIABLE(0);
NDArray* mean = INPUT_VARIABLE(1);
NDArray* variance = INPUT_VARIABLE(2);
NDArray* dLdO = INPUT_VARIABLE(3); // next epsilon
NDArray* gamma = nullptr;
NDArray* beta = nullptr;
NDArray *dLdO = nullptr; // next epsilon
auto dLdI = OUTPUT_VARIABLE(0);
auto dLdM = OUTPUT_VARIABLE(1);
auto dLdV = OUTPUT_VARIABLE(2);
NDArray* dLdI = OUTPUT_VARIABLE(0);
NDArray* dLdM = OUTPUT_VARIABLE(1);
NDArray* dLdV = OUTPUT_VARIABLE(2);
NDArray* dLdG = nullptr;
NDArray* dLdB = nullptr;
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
// FIXME: double?
const double epsilon = T_ARG(0);
const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset);
const float epsilon = T_ARG(0);
if(applyScale) {
gamma = INPUT_VARIABLE(3);
gamma = INPUT_VARIABLE(4);
dLdG = OUTPUT_VARIABLE(3);
}
if(applyOffset) {
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
dLdB = OUTPUT_VARIABLE(3 + static_cast<int>(applyScale));
beta = INPUT_VARIABLE(4 + (int)applyScale);
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
}
dLdO = INPUT_VARIABLE(3 + dLdONum);
const int numOfIntArgs = block.getIArguments()->size();
const int inRank = input->rankOf();
std::vector<const NDArray*> inArrs(block.width());
for(int i = 0; i < 4 + dLdONum; ++i)
inArrs[i] = INPUT_VARIABLE(i);
// get axes args to normalize input array over
std::vector<int> axes;
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
// check whether all input shapes are mutually broadcastable
Nd4jLong* outShapeInfo = nullptr;
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !");
const int numOfAxes = axes.size();
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
std::vector<Nd4jLong> expShape;
if(numOfAxes == 1)
expShape.push_back(input->sizeAt(axes[0]));
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
expShape = std::vector<Nd4jLong>(inRank, 1);
for(uint i = 0; i < numOfAxes; ++i)
expShape[axes[i]] = input->sizeAt(axes[i]);
}
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
if(gamma)
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
if(beta)
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
// types of all input arrays should be the same (except dLdO)
for(int i = 1; i < block.width() - 1; ++i)
if(i != 3)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
// ***** calculations ***** //
auto sigmaInv = (*variance + epsilon).transform(transform::RSqrt);
// formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
NDArray sigmaInvGamdLdO = -sigmaInv * *dLdO;
// consider mean and variance as constants (since we get them as inputs and don't calculate them)
// dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
// dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
// dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
// dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
// dLdB = dLdO_sum
const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
NDArray temp1 = *variance + epsilon;
temp1.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
auto temp2 = temp1.transform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
if(applyScale)
sigmaInvGamdLdO *= *gamma;
temp2 *= *gamma; // gamma / (variance + epsilon)^0.5
NDArray inputMinusMean;
if(!input->isSameShape(dLdO) && !mean->isSameShape(dLdO)) {
auto inputTiled = NDArray(dLdO, false, block.launchContext());
input->tile(inputTiled);
inputMinusMean = inputTiled - *mean;
}
else
inputMinusMean = *input - *mean;
NDArray temp3(input); // empty array with same shape as input
input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3); // input - mean
temp3 *= *dLdO; // (input - mean) * dLdO
const bool keepUnitiesInShape = inRank == mean->rankOf();
// dLdI
if(!dLdI->isSameShape(dLdO))
dLdI->assign( (-sigmaInvGamdLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdI->getShapeInfo(), dLdO->getShapeInfo())) );
else
dLdI->assign(-sigmaInvGamdLdO);
dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI);
// dLdM
if(!dLdM->isSameShape(dLdO))
dLdM->assign( sigmaInvGamdLdO.reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdM->getShapeInfo(), dLdO->getShapeInfo())) );
else
dLdM->assign(sigmaInvGamdLdO);
dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); // dLdO sum over excluded axes
// dLdB
if(applyOffset)
dLdB->assign(dLdM);
// dLdM
// dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
// dLdM->applyTransform(nd4j::transform::Neg);
*dLdM = 0; // put zeros so far
//dLdV
if(!dLdV->isSameShape(dLdO)) {
dLdV->assign( (sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdV->getShapeInfo(), dLdO->getShapeInfo())) );
}
else
dLdV->assign(sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f);
temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); // ((input - mean) * dLdO)_sum
// dLdG
if(applyScale) {
if(!dLdG->isSameShape(dLdO))
dLdG->assign( (sigmaInv * inputMinusMean * *dLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdG->getShapeInfo(), dLdO->getShapeInfo())) );
else
dLdG->assign(sigmaInv * inputMinusMean * *dLdO);
dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG);
// dLdV->assign(dLdG);
dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma);
}
else
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
// dLdB
if(applyOffset) {
if(!dLdB->isSameShape(dLdO))
dLdB->assign(dLdO->reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdB->getShapeInfo(), dLdO->getShapeInfo())) );
else
dLdB->assign(dLdO);
}
// dLdV
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
// *dLdV *= -0.5;
*dLdV = 0; // put zeros so far
return Status::OK();
}
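A minimal per-channel sketch of the gradient formulas listed in the comments above (mean and variance treated as constants, channel-last layout; dLdM and dLdV are zeroed just as the op does; names are illustrative):

    #include <cmath>
    #include <cstddef>

    static void batchnormBpSketch(const float* x, const float* mean, const float* variance, const float* gamma,
                                  const float* dLdO, float* dLdI, float* dLdG, float* dLdB,
                                  std::size_t numSamples, std::size_t numChannels, float epsilon) {
        for (std::size_t c = 0; c < numChannels; ++c) {
            const float sigmaInv = 1.0f / std::sqrt(variance[c] + epsilon);   // 1 / (variance + epsilon)^0.5
            float sumDLdO = 0.0f, sumDLdOxMu = 0.0f;
            for (std::size_t i = 0; i < numSamples; ++i) {
                const std::size_t idx = i * numChannels + c;
                dLdI[idx] = dLdO[idx] * gamma[c] * sigmaInv;       // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
                sumDLdO    += dLdO[idx];
                sumDLdOxMu += dLdO[idx] * (x[idx] - mean[c]);
            }
            dLdG[c] = sumDLdOxMu * sigmaInv;                       // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
            dLdB[c] = sumDLdO;                                     // dLdB = dLdO_sum
            // dLdM and dLdV are written as zeros by the op above, since mean and variance arrive as inputs
        }
    }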
@ -285,9 +241,9 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, nd4j::DataType::ANY)
->setAllowedInputTypes(2, nd4j::DataType::ANY)
->setAllowedInputTypes(3, nd4j::DataType::ANY)
->setAllowedInputTypes(3, {ALL_FLOATS})
->setAllowedInputTypes(4, nd4j::DataType::ANY)
->setAllowedInputTypes(5, {ALL_FLOATS})
->setAllowedInputTypes(5, nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
@ -295,179 +251,35 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
DECLARE_SHAPE_FN(batchnorm_bp) {
Nd4jLong* inShapeInfo = inputShape->at(0);
Nd4jLong* meanShapeInfo = inputShape->at(1);
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset);
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
std::vector<const NDArray*> inArrs(block.width());
for(int i = 0; i < 4 + dLdONum; ++i)
inArrs[i] = INPUT_VARIABLE(i);
auto shapes = SHAPELIST();
// check whether all input shapes are mutually broadcastable
Nd4jLong* outShapeInfo = nullptr;
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !");
// dLdI shapeInfo
shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo));
Nd4jLong* dLdIShapeInfo(nullptr), *dLdMShapeInfo(nullptr), *dLdVShapeInfo(nullptr), *dLdGShapeInfo(nullptr), *dLdBShapeInfo(nullptr);
COPY_SHAPE(inputShape->at(0), dLdIShapeInfo);
COPY_SHAPE(inputShape->at(1), dLdMShapeInfo);
COPY_SHAPE(inputShape->at(2), dLdVShapeInfo);
// dLdM shapeInfo
shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo));
if(applyScale) {
COPY_SHAPE(inputShape->at(3), dLdGShapeInfo);
// dLdV shapeInfo (same as dLdM)
shapes->push_back(shapes->at(shapes->size()-1));
// dLdG shapeInfo (same as dLdM)
if(applyScale)
shapes->push_back(shapes->at(shapes->size()-1));
// dLdB shapeInfo (same as dLdM)
if(applyOffset)
shapes->push_back(shapes->at(shapes->size()-1));
return shapes;
}
if(applyOffset){
COPY_SHAPE(inputShape->at(3 + static_cast<int>(applyScale)), dLdBShapeInfo);
}
if(!applyScale && !applyOffset)
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo));
if(applyScale && !applyOffset)
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo));
if(!applyScale && applyOffset)
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdBShapeInfo));
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo), CONSTANT(dLdBShapeInfo));
}
// //////////////////////////////////////////////////////////////////////////
// CONFIGURABLE_OP_IMPL(batchnorm_bp, 5, 1, true, 0, 1) {
// NDArray<T>* input = INPUT_VARIABLE(0);
// NDArray<T>* epsilon = INPUT_VARIABLE(1);
// NDArray<T>* gamma = INPUT_VARIABLE(2);
// NDArray<T>* dGlobalMeanView = INPUT_VARIABLE(3);
// NDArray<T>* dGlobalVarView = INPUT_VARIABLE(4);
// NDArray<T>* outEpsilon = this->getZ(block);
// std::vector<int> argI = *(block.getIArguments());
// const int bS = epsilon->sizeAt(0);
// bool isLockGammaBeta = (bool)argI[0];
// const int* epsilonShape = epsilon->getShapeInfo() + 1;
// const T eps = (T)1e-5;
// int rank = epsilon->rankOf();
// std::initializer_list<int> dimensions;
// int effectiveBatchSize;
// if (rank == 2) {
// dimensions = {0};
// effectiveBatchSize = bS;
// }
// else if (rank == 4) {
// dimensions = {0, 2, 3};
// effectiveBatchSize = input->sizeAt(0)*input->sizeAt(2)*input->sizeAt(3);
// }
// else
// throw "Graph operation batchnorm_bp: the epsilon rank must be equal to 2 or 4 !";
// NDArray<T> *mean(nullptr), *var(nullptr), *dBeta(nullptr), *dGamma(nullptr), *dLdVar(nullptr), *dxmu1(nullptr), *dxmu2(nullptr);
// mean = input->template reduceAlongDimension<simdOps::Mean<T>>(dimensions);
// var = input->template varianceAlongDimension<simdOps::SummaryStatsVariance<T>>(false, dimensions);
// var->template applyScalar<simdOps::Add<T>>(eps, nullptr);
// auto std = new NDArray<T>(var->getShapeInfo(), block.getWorkspace());
// var->template applyTransform<simdOps::Sqrt<T>>(std, nullptr);
// auto xMu = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
// auto xHat = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
// auto temp1 = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
// auto temp2 = new NDArray<T>(std->getShapeInfo(), block.getWorkspace());
// auto dGammaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
// auto dBetaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
// auto dxhat = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
// if (rank == 2) {
// input->subRowVector(mean, xMu);
// xMu->divRowVector(std, xHat);
// }
// else {
// input->template applyBroadcast<simdOps::Subtract<T>>({1}, mean, xMu, nullptr);
// xMu->template applyBroadcast<simdOps::Divide<T>>({1}, std, xHat, nullptr);
// }
// dBeta = epsilon->sum(dimensions); // dL/dBeta = sum_examples dL/dOut
// epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(xHat, temp1, nullptr); //dL/dGamma = sum_examples dL/dOut .* xHat
// dGamma = temp1->sum(dimensions); //dL/dGamma = sum_examples dL/dOut .* xHat
// if (isLockGammaBeta)
// epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(gamma, dxhat, nullptr);
// else {// Standard case
// if(rank == 2)
// epsilon->mulRowVector(gamma, dxhat); //dL/dxHat = dL/dOut . gamma Shape: [minibatchSize, nOut]
// else
// epsilon->template applyBroadcast<simdOps::Multiply<T>>({1}, gamma, dxhat, nullptr);
// }
// // dLdVar - dL/dVariance, shape: [1, miniBatch]
// dxhat->template applyPairwiseTransform<simdOps::Multiply<T>>(xMu, temp1, nullptr);
// dLdVar = temp1->sum(dimensions);
// dLdVar->template applyScalar<simdOps::Multiply<T>>((T)-0.5, nullptr);
// T powParams[] = {(T)(-3.)};
// std->template applyTransform<simdOps::Pow<T>>(temp2, powParams);
// dLdVar->template applyPairwiseTransform<simdOps::Multiply<T>>(temp2, nullptr);
// //dL/dmu
// dxmu1 = dxhat->sum(dimensions);
// dxmu1->template applyPairwiseTransform<simdOps::Divide<T>>(std, nullptr);
// dxmu1->template applyTransform<simdOps::Neg<T>>();
// dxmu2 = xMu->sum(dimensions);
// dxmu2->template applyScalar<simdOps::Multiply<T>>((T)(-2.)/effectiveBatchSize);
// dxmu2->template applyPairwiseTransform<simdOps::Multiply<T>>(dLdVar, nullptr);
// dxmu1->template applyPairwiseTransform<simdOps::Add<T>>(dxmu2, nullptr);
// NDArray<T>* dLdmu = dxmu1; // = dL/dmu Shape: [1, nOut]
// //Note the array reuse here: dxhat, xMu, dLdVar, dLdmu - all are invalid after this line (but aren't used later anyway)
// NDArray<T>* dLdx = dxhat;
// dLdVar->template applyScalar<simdOps::Multiply<T>>((T)(2.)/effectiveBatchSize);
// dLdmu->template applyScalar<simdOps::Multiply<T>>((T)(1.)/effectiveBatchSize);
// if(rank == 2) {
// dLdx->divRowVector(std, dLdx);
// xMu->mulRowVector(dLdVar, xMu);
// }
// else {
// dLdx->template applyBroadcast<simdOps::Divide<T>>({1}, std, dLdx, nullptr);
// xMu->template applyBroadcast<simdOps::Multiply<T>>({1}, dLdVar, xMu, nullptr);
// }
// dLdx->template applyPairwiseTransform<simdOps::Add<T>>(xMu, nullptr);
// if(rank == 2)
// dLdx->addRowVector(dLdmu, dLdx);
// else
// dLdx->template applyBroadcast<simdOps::Add<T>>({1}, dLdmu, dLdx, nullptr);
// *outEpsilon = *dLdx;
// //TODO rework this to avoid the assign here
// // dGammaView->assign(dGamma);
// // dBetaView->assign(dBeta);
// // dGlobalMeanView->assign((T)0.);
// // dGlobalVarView->assign((T)0.);
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
// delete std;
// delete xMu;
// delete xHat;
// delete mean;
// delete var;
// delete dBeta;
// delete dGamma;
// delete dLdVar;
// delete dxmu1;
// delete dxmu2;
// delete temp1;
// delete temp2;
// delete dxhat;
// delete dGammaView;
// delete dBetaView;
// return ND4J_STATUS_OK;
// }
}


@ -91,9 +91,6 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_batchnorm)
DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2);
#endif
#if NOT_EXCLUDED(OP_batchnorm_new)
DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2);
#endif
/**
* back prop in batch normalization
@ -117,8 +114,8 @@ namespace nd4j {
* dL/dInput
* dL/dMean
* dL/dVariance
* dL/dGamma
* dL/dBeta
* dL/dGamma, optional
* dL/dBeta, optional
*/
#if NOT_EXCLUDED(OP_batchnorm)
DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2);


@ -32,6 +32,8 @@ namespace helpers {
template <typename T>
static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector<int>& axes, const double epsilon) {
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo
T eps = epsilon;


@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
@ -28,19 +29,296 @@
#include <ops/declarable/helpers/convolutions.h>
#include <NDArrayFactory.h>
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
PLATFORM_IMPL(batchnorm_new) {
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray *gamma = nullptr;
NDArray *beta = nullptr;
auto output = OUTPUT_VARIABLE(0);
//////////////////////////////////////////////////////////////////////////
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
// also it gives wrong results for formats nhwc and ndhwc
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
// mean -> 1D [c]
// variance -> 1D [c]
// weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
// z(output) - same shape as x
const int xRank = x->rankOf();
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// input type
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
// indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats;
if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift;
mkldnn::memory::dims dims;
mkldnn::memory::format_tag format;
if(xRank == 2) {
dims = {x->sizeAt(0), x->sizeAt(1)};
format = mkldnn::memory::format_tag::nc;
}
else if(xRank == 4) {
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
format = mkldnn::memory::format_tag::nchw;
}
else { // xRank = 5
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
format = mkldnn::memory::format_tag::ncdhw;
}
// memory descriptors for arrays
// x
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
if(xRank > 2) {
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
}
if(xRank > 4)
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// z, output
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format);
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format);
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
if(xRank > 2) {
z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2];
z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3];
}
if(xRank > 4)
z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4];
// batchnorm forward description
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory and check whether reorder is required
// x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// z
auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer());
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
if (zReorder)
mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
args[MKLDNN_ARG_DST] = z_mkl_mem;
// mean
auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
// variance
auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
// gamma and beta (and their gradients) if they are present
if(weights != nullptr) {
auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
}
// run calculations
mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
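As the comments above state, the optional weights argument is a 2 x c array with gamma in the first row and beta in the second; dLdW in the backward helper below follows the same layout for grad_gamma and grad_beta. A tiny illustrative layout, assuming c = 3:

    // weights (2 x 3), row-major:
    //   row 0: gamma[0] gamma[1] gamma[2]   -> accessed as weights({0,1, 0,0})
    //   row 1: beta[0]  beta[1]  beta[2]    -> accessed as weights({1,2, 0,0})
    // The platform implementations further down build it exactly this way, e.g.:
    //   NDArray* weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
    //   (*weights)({0,1, 0,0}).assign(gamma);   // or 1 if gamma is absent
    //   (*weights)({1,2, 0,0}).assign(beta);    // or 0 if beta is absent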
//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
const float epsilon, NDArray* dLdI, NDArray* dLdW) {
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
// also it gives wrong results for formats nhwc and ndhwc
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
// mean -> 1D [c]
// variance -> 1D [c]
// dLdO - same shape as x
// weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
// dLdI - same shape as x
// dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta
const int xRank = x->rankOf();
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// input type
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
// indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats;
if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift;
mkldnn::memory::dims dims;
mkldnn::memory::format_tag format;
if(xRank == 2) {
dims = {x->sizeAt(0), x->sizeAt(1)};
format = mkldnn::memory::format_tag::nc;
}
else if(xRank == 4) {
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
format = mkldnn::memory::format_tag::nchw;
}
else { // xRank = 5
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
format = mkldnn::memory::format_tag::ncdhw;
}
// memory descriptors for arrays
// x
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
if(xRank > 2) {
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
}
if(xRank > 4)
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
// dLdO
mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format);
mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format);
dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
if(xRank > 2) {
dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2];
dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3];
}
if(xRank > 4)
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
// dLdI
mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format);
mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format);
dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
if(xRank > 2) {
dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2];
dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3];
}
if(xRank > 4)
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
// batchnorm forward description
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
// batchnorm backprop description
mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
mkldnn::stream stream(engine);
// provide memory and check whether reorder is required
// x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[MKLDNN_ARG_SRC] = x_mkl_mem;
// dLdO
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc();
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem;
if (dLdOReorder)
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
// mean
auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
// variance
auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
// dLdI
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc();
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem;
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
// gamma and beta (and their gradients) if they are present
if(weights != nullptr) {
auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
}
// run calculations
mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (dLdIReorder)
mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
stream.wait();
// shape::printArray(dLdI_mkl_mem.map_data<float>(),8);
}
PLATFORM_IMPL(batchnorm) {
auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
auto mean = INPUT_VARIABLE(1); // [c]
auto variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
NDArray* beta = nullptr; // [c]
auto output = OUTPUT_VARIABLE(0); // same shape as input
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
@ -49,118 +327,381 @@ namespace nd4j {
if(applyScale)
gamma = INPUT_VARIABLE(3);
if(applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
beta = INPUT_VARIABLE(3 + (int)applyScale);
const int numOfIntArgs = block.getIArguments()->size();
const int inRank = input->rankOf();
// get axes args to normalize input array over
std::vector<int> axes;
if (block.numI() > 2)
for (int i = 2; i < block.numI(); ++i)
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(input->rankOf() - 1);
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
weights({0, 1, 0, 0}).assign(1.0f);
weights({1, 2, 0, 0}).assign(0.0f);
const int numOfAxes = axes.size();
REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
if(gamma != nullptr)
REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
if(beta != nullptr)
REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(
empty), user_dst_md(empty);
// types of all input arrays should be the same (except dLdO)
for(int i = 1; i < block.width() - 1; ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !");
auto norm_flag = normalization_flags::use_global_stats;
if (applyScale || applyOffset)
norm_flag |= normalization_flags::use_scale_shift;
mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
&user_src_md, nullptr, &user_dst_md, axes[0]);
NDArray *weights = nullptr;
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn::stream stream(engine);
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
mean->buffer());
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
variance->buffer());
auto batchnorm_src_memory = user_src_memory;
mkldnn::memory m(batchnorm_src_md, engine);
if (m.get_desc() != user_src_memory.get_desc()) {
batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
batchnorm_src_memory);
}
auto batchnorm_dst_memory = user_dst_memory;
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
}
if(applyScale || applyOffset) {
if (gamma != nullptr) {
weights({0, 1, 0, 0}).assign(gamma);
}
if (beta != nullptr) {
weights({1, 2, 0, 0}).assign(beta);
weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
if(applyScale)
(*weights)({0,1, 0,0}).assign(gamma);
else
(*weights)({0,1, 0,0}).assign(1);
if(applyOffset)
(*weights)({1,2, 0,0}).assign(beta);
else
(*weights)({1,2, 0,0}).assign(0);
}
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
{MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
} else {
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
}
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
user_dst_memory);
}
stream.wait();
batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
delete weights;
return Status::OK();
}
PLATFORM_CHECK(batchnorm_new) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
if (::optimalLevel() < 2)
return false;
// if (::optimalLevel() < 2)
// return false;
auto input = INPUT_VARIABLE(0);
auto mean = INPUT_VARIABLE(1);
auto variance = INPUT_VARIABLE(2);
NDArray *gamma = nullptr;
NDArray *beta = nullptr;
auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
auto mean = INPUT_VARIABLE(1); // [c]
auto variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
NDArray* beta = nullptr; // [c]
auto output = OUTPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); // same shape as input
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
const double epsilon = T_ARG(0);
if(applyScale)
gamma = INPUT_VARIABLE(3);
if(applyOffset)
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
beta = INPUT_VARIABLE(3 + (int)applyScale);
const int numOfIntArgs = block.getIArguments()->size();
std::vector<int> axes;
if (block.numI() > 2)
for (int i = 2; i < block.numI(); ++i)
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(input->rankOf() - 1);
axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension
DataType inputType = input->dataType();
DataType meanType = mean->dataType();
DataType varType = variance->dataType();
DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32;
DataType outType = output->dataType();
const int inRank = input->rankOf();
return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
(inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
}
//////////////////////////////////////////////////////////////////////////
// PLATFORM_IMPL(batchnorm) {
// auto input = INPUT_VARIABLE(0);
// auto mean = INPUT_VARIABLE(1);
// auto variance = INPUT_VARIABLE(2);
// NDArray *gamma = nullptr;
// NDArray *beta = nullptr;
// auto output = OUTPUT_VARIABLE(0);
// const bool applyScale = (bool) INT_ARG(0);
// const bool applyOffset = (bool) INT_ARG(1);
// const double epsilon = T_ARG(0);
// if (applyScale)
// gamma = INPUT_VARIABLE(3);
// if (applyOffset)
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
// std::vector<int> axes;
// if (block.numI() > 2)
// for (int i = 2; i < block.numI(); ++i)
// axes.push_back(INT_ARG(i));
// else
// axes.push_back(input->rankOf() - 1);
// std::vector<Nd4jLong> shape({2, mean->lengthOf()});
// NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
// weights({0, 1, 0, 0}).assign(1.0f);
// weights({1, 2, 0, 0}).assign(0.0f);
// mkldnn_memory_desc_t empty;
// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
// auto flag = mkldnn::normalization_flags::use_global_stats;
// if (applyScale || applyOffset)
// flag |= mkldnn::normalization_flags::use_scale_shift;
// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
// &batchnorm_src_md, nullptr, &batchnorm_dst_md,
// &user_src_md, nullptr, &user_dst_md, axes[0]);
// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// mkldnn::stream stream(engine);
// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
// mean->buffer());
// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
// variance->buffer());
// auto batchnorm_src_memory = user_src_memory;
// mkldnn::memory m(batchnorm_src_md, engine);
// if (m.get_desc() != user_src_memory.get_desc()) {
// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
// batchnorm_src_memory);
// }
// auto batchnorm_dst_memory = user_dst_memory;
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
// }
// if (applyScale || applyOffset) {
// if (gamma != nullptr) {
// weights({0, 1, 0, 0}).assign(gamma);
// }
// if (beta != nullptr) {
// weights({1, 2, 0, 0}).assign(beta);
// }
// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// } else {
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// }
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
// user_dst_memory);
// }
// stream.wait();
// return Status::OK();
// }
//////////////////////////////////////////////////////////////////////////
// PLATFORM_CHECK(batchnorm) {
// // we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
// auto input = INPUT_VARIABLE(0);
// auto mean = INPUT_VARIABLE(1);
// auto variance = INPUT_VARIABLE(2);
// NDArray *gamma = nullptr;
// NDArray *beta = nullptr;
// auto output = OUTPUT_VARIABLE(0);
// const bool applyScale = (bool) INT_ARG(0);
// const bool applyOffset = (bool) INT_ARG(1);
// const double epsilon = T_ARG(0);
// if (applyScale)
// gamma = INPUT_VARIABLE(3);
// if (applyOffset)
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
// std::vector<int> axes;
// if (block.numI() > 2)
// for (int i = 2; i < block.numI(); ++i)
// axes.push_back(INT_ARG(i));
// else
// axes.push_back(input->rankOf() - 1);
// return block.isUseMKLDNN() &&
// nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
// axes.size() == 1;
// }
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp) {
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
NDArray* mean = INPUT_VARIABLE(1); // [c]
NDArray* variance = INPUT_VARIABLE(2); // [c]
NDArray* dLdO = INPUT_VARIABLE(3); // same as input
NDArray* gamma = nullptr; // [c]
NDArray* beta = nullptr; // [c]
NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input
NDArray* dLdM = OUTPUT_VARIABLE(1); // [c]
NDArray* dLdV = OUTPUT_VARIABLE(2); // [c]
NDArray* dLdG = nullptr; // [c]
NDArray* dLdB = nullptr; // [c]
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
const float epsilon = T_ARG(0);
if(applyScale) {
gamma = INPUT_VARIABLE(4);
dLdG = OUTPUT_VARIABLE(3);
}
if(applyOffset) {
beta = INPUT_VARIABLE(4 + (int)applyScale);
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
}
const int numOfIntArgs = block.getIArguments()->size();
const int inRank = input->rankOf();
// get axes args to normalize input array over
std::vector<int> axes;
if(numOfIntArgs > 2)
for(int i = 2; i < numOfIntArgs; ++i)
axes.push_back(INT_ARG(i));
else
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
const int numOfAxes = axes.size();
REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
if(gamma != nullptr)
REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
if(beta != nullptr)
REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
// types of all input arrays should be the same (except dLdO)
for(int i = 1; i < block.width() - 1; ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !");
NDArray *weights = nullptr, *dLdW = nullptr;
if(applyScale || applyOffset) {
weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
if(applyScale)
(*weights)({0,1, 0,0}).assign(gamma);
else
(*weights)({0,1, 0,0}).assign(1);
if(applyOffset)
(*weights)({1,2, 0,0}).assign(beta);
else
(*weights)({1,2, 0,0}).assign(0);
}
*dLdM = 0;
*dLdV = 0;
batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
if(applyScale || applyOffset) {
if(applyScale)
dLdG->assign((*dLdW)({0,1, 0,0}));
if(applyOffset)
dLdB->assign((*dLdW)({1,2, 0,0}));
delete weights;
delete dLdW;
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm_bp) {
    // we don't want to use mkldnn if cpu doesn't support avx/avx2
    // if (::optimalLevel() < 2)
    //     return false;

    NDArray* input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* mean     = INPUT_VARIABLE(1);   // [c]
    NDArray* variance = INPUT_VARIABLE(2);   // [c]
    NDArray* dLdO     = INPUT_VARIABLE(3);   // same as input
    NDArray* gamma    = nullptr;             // [c]
    NDArray* beta     = nullptr;             // [c]

    NDArray* dLdI = OUTPUT_VARIABLE(0);      // same as input
    NDArray* dLdM = OUTPUT_VARIABLE(1);      // [c]
    NDArray* dLdV = OUTPUT_VARIABLE(2);      // [c]
    NDArray* dLdG = nullptr;                 // [c]
    NDArray* dLdB = nullptr;                 // [c]

    const bool applyScale  = (bool)INT_ARG(0);
    const bool applyOffset = (bool)INT_ARG(1);

    if(applyScale) {
        gamma = INPUT_VARIABLE(4);
        dLdG  = OUTPUT_VARIABLE(3);
    }
    if(applyOffset) {
        beta = INPUT_VARIABLE(4 + (int)applyScale);
        dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
    }

    const int numOfIntArgs = block.getIArguments()->size();
    std::vector<int> axes;
    if(numOfIntArgs > 2)
        for(int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(input->rankOf()-1);   // default dimension to reduce along is the last dimension

    DataType inputType = input->dataType();
    DataType meanType  = mean->dataType();
    DataType varType   = variance->dataType();
    DataType dLdOType  = dLdO->dataType();
    DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
    DataType betaType  = beta  != nullptr ? beta->dataType()  : DataType::FLOAT32;

    DataType dLdIType = dLdI->dataType();
    DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32;
    DataType dLdBType = beta  != nullptr ? dLdB->dataType() : DataType::FLOAT32;

    const int inRank = input->rankOf();

    // the mkl-dnn path is taken only for float32 arrays with the channel axis equal to 1 (nc, nchw, ncdhw layouts)
    return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
           (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
            dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 &&
            dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);
}
return block.isUseMKLDNN() &&
nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
axes.size() == 1;
}
}
}
}
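
Taken together, the implementation and the platform check above fix the calling convention for this op: inputs are {input, mean, variance, dLdO[, gamma][, beta]}, outputs are {dLdI, dLdM, dLdV[, dLdG][, dLdB]}, and the MKL-DNN path is used only for float32 arrays whose channel axis is 1 (NC, NCHW, NCDHW). A minimal sketch of driving the op from C++, mirroring the tests further down in this commit; shapes and values are illustrative only:

NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);                      // NCHW, c = 4
NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1},   nd4j::DataType::FLOAT32);
NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);                      // dL/dOutput, same shape as input
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);

nd4j::ops::batchnorm_bp op;
// t-args: {epsilon}; i-args: {applyScale, applyOffset, channel axis}
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});

auto dLdI = results->at(0);   // same shape as input
auto dLdG = results->at(3);   // [c]
auto dLdB = results->at(4);   // [c]; results->at(1) and results->at(2) hold dLdM and dLdV, which the code above returns as zeros
delete results;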


@@ -132,8 +132,6 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;


@@ -305,50 +305,50 @@ namespace nd4j {
};
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
const Nd4jLong* shape = src->getShapeInfo();
Nd4jLong rank = shape[0];
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
// mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
// mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
// const Nd4jLong* shape = src->getShapeInfo();
// Nd4jLong rank = shape[0];
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
// mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
auto type = mkldnn::memory::data_type::f32;
auto format = mkldnn::memory::format_tag::nchw;
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
// auto type = mkldnn::memory::data_type::f32;
// auto format = mkldnn::memory::format_tag::nchw;
// auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_src_md->data.format_kind = mkldnn_blocked; // overrides format
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
}
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
// *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_src_md->data.format_kind = mkldnn_blocked; // overrides format
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
// }
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
}
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
// *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
// }
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
}
};
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
// *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
// *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
// user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
// }
// };
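
The helper retired above built MKL-DNN memory descriptors by starting from a plain nchw descriptor and then overwriting its strides with the NDArray's actual strides, so arbitrarily strided inputs could be handed to the primitive. A condensed sketch of that idiom, using only the calls visible in the commented-out code; N, C, H, W and src are placeholders for the tensor sizes and the source NDArray*, not identifiers from this commit:

const int N = 2, C = 4, H = 2, W = 2;                                            // placeholder sizes
mkldnn::memory::dims dims = {N, C, H, W};
mkldnn::memory::desc user_md(dims, mkldnn::memory::data_type::f32, mkldnn::memory::format_tag::nchw);
user_md.data.format_kind = mkldnn_blocked;                                       // override the tag-derived layout
user_md.data.format_desc.blocking.strides[0] = src->stridesOf()[0];              // feed the array's real strides
user_md.data.format_desc.blocking.strides[1] = src->stridesOf()[1];
user_md.data.format_desc.blocking.strides[2] = src->stridesOf()[2];
user_md.data.format_desc.blocking.strides[3] = src->stridesOf()[3];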
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,


@@ -62,7 +62,9 @@ namespace nd4j{
DECLARE_PLATFORM(lrn);
DECLARE_PLATFORM(batchnorm_new);
DECLARE_PLATFORM(batchnorm);
DECLARE_PLATFORM(batchnorm_bp);
DECLARE_PLATFORM(lstmLayer);
}


@@ -413,7 +413,7 @@ namespace nd4j {
return ctx;
};
nd4j::ops::batchnorm_new batchnorm;
nd4j::ops::batchnorm batchnorm;
DeclarableBenchmark benchmark(batchnorm, "batchnorm");
output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization");


@@ -2385,129 +2385,6 @@ TEST_F(DeclarableOpsTests1, CompactLaunchTests2) {
ASSERT_TRUE(exp.equalsTo(&z));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, batchnorm_test1) {
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto mean = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto variance = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto gamma = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto beta = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011,10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
beta.assign(1.);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
TEST_F(DeclarableOpsTests1, batchnorm_test2) {
auto input = NDArrayFactory::create<double>('c', {2,3,1,3,1});
auto mean = NDArrayFactory::create<double>('c', {1,3,2,1,2});
auto variance = NDArrayFactory::create<double>('c', {2,1,2,3,2});
auto gamma = NDArrayFactory::create<double>('c', {2,3,2,3,1});
auto beta = NDArrayFactory::create<double>('c', {1,3,2,1,2});
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
beta.assign(1.);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, batchnorm_test3) {
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1});
auto gamma = NDArrayFactory::create<double>('c', {1,1});
auto beta = NDArrayFactory::create<double>('c', {1,2});
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011, 10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
beta.assign(1.);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, batchnorm_test4) {
auto input = NDArrayFactory::create<double>('c', {3,2});
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
auto variance= NDArrayFactory::create<double>('c', {2,3,1,3,2});
auto gamma = NDArrayFactory::create<double>('c', {1,1});
auto beta = NDArrayFactory::create<double>('c', {1,2});
auto expected= NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
beta.assign(1.);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, sru_old_test1) {


@@ -2313,7 +2313,35 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
// output->printBuffer();
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
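
For reference, the expected values above follow the inference form of batch normalization, output = gamma * (input - mean) / sqrt(variance + epsilon) + beta, applied per channel along the last axis: e.g. the first expected entry is -1.2 * (0.1 - 1.05) / sqrt(0.5 + 1e-5) + 10 ≈ 11.6122, matching 11.61218734.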
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) {
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
auto mean = NDArrayFactory::create<TypeParam>('c', {4});
@@ -2330,7 +2358,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
gamma.assign(1.2);
beta.assign(1.);
nd4j::ops::batchnorm_new op;
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
@@ -2346,7 +2374,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) {
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
auto mean = NDArrayFactory::create<TypeParam>('c', {3}, {1.05, 1.1, 1.15});
@@ -2359,7 +2387,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm_new op;
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
@@ -2374,7 +2402,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
auto mean = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
@@ -2387,7 +2415,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm_new op;
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});
@@ -2401,6 +2429,63 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20.,
-19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438,
-12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
// output->printBuffer();
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
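
The same formula applies here, but the channel axis is passed explicitly as the third integer argument (1, i.e. NCHW), so the [4] parameters broadcast along axis 1: e.g. expected[0,1,0,0] is 1.3 * (0.5 - 1.15) / sqrt(0.7 + 1e-5) + 20 ≈ 18.9900, matching the fifth value above.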
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 ,
20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 ,
-12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {


@@ -2883,78 +2883,336 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
auto input = NDArrayFactory::create<double>('c', {3,2});
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,2});
auto gamma = NDArrayFactory::create<double>('c', {1,1});
auto beta = NDArrayFactory::create<double>('c', {1,2});
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3,2});
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556, 0., 0.254556,0.509112, 0.763668, 1.018224, 1.272779,
1.527335, 1.781891, 2.036447, 2.291003,2.545559, 2.800115, 3.054671, 3.309227,3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
beta.assign(1.);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1,1});
nd4j::ops::batchnorm_bp op;
nd4j::ops::batchnorm opFF;
nd4j::ops::batchnorm_bp opBP;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(isGradCorrect);
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1});
auto gamma = NDArrayFactory::create<double>('c', {1,1});
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3,2});
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray mean ('c', {3}, {1.05, 1.1, 1.15});
NDArray variance('c', {3}, {0.5, 0.6, 0.7});
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4});
NDArray beta ('c', {3}, nd4j::DataType::DOUBLE);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742, 0., 0.251742,0.501992, 0.752989, 1.003985, 1.254981,
1.527335, 1.781891, 2.036447, 2.291003,2.517418, 2.76916 , 3.020902, 3.272644,3.513947, 3.764943, 4.015939, 4.266936});
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388});
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
gamma.assign(1.2);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma}, {1e-5}, {1,0});
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1,0});
nd4j::ops::batchnorm_bp op;
nd4j::ops::batchnorm opFF;
nd4j::ops::batchnorm_bp opBP;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(isGradCorrect);
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
auto input = NDArrayFactory::create<double>('c', {2,3,1,3});
auto mean = NDArrayFactory::create<double>('c', {1,3,2,1});
auto variance = NDArrayFactory::create<double>('c', {2,1,2,3});
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3});
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9});
NDArray beta ('c', {2,1,4}, nd4j::DataType::DOUBLE);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742, 0., 0.251556,0.509112, 0.755225, 1.003985, 1.25778 ,
1.517885, 1.784991, 2.05947 , 2.341504,2.529808, 2.804986, 3.089205, 3.382173,3.541731, 3.824981, 4.11894 , 4.422841});
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 });
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85});
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
// beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
const OpArgsHolder argsHolderFF({&input, &mean, &variance}, {1e-5}, {0,0});
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &dLdO}, {1e-5}, {0,0});
nd4j::ops::batchnorm_bp op;
nd4j::ops::batchnorm opFF;
nd4j::ops::batchnorm_bp opBP;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2});
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(isGradCorrect);
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {1.442483, 0.9502 , 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
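
Because mean and variance enter batchnorm_bp as inputs rather than being recomputed from the batch, the expected gradients in these _bp tests reduce to simple per-channel forms: dLdI = gradO * gamma / sqrt(variance + eps), dLdG_c = sum of gradO * (input - mean) / sqrt(variance + eps) over all non-channel positions, and dLdB_c = sum of gradO over all non-channel positions. For channel 0 of the test above: dLdB_0 = -0.9 + (-0.3) = -1.2, dLdG_0 = (-0.9)*(-1.3435) + (-0.3)*(-0.7778) ≈ 1.4425, and the first dLdI entry is (-0.9)*(-1.2)/sqrt(0.5 + 1e-5) ≈ 1.5273, all matching the expected arrays.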
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779,1.018224, 0.763668,-0.466136, -0.233068,0., 0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169, 1.501697,1.716225, 1.930753,
-2.545559, -2.800115,-3.054671, -3.309227,3.262951, 3.496019,3.729087, 3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618, 4.934146,5.148675, 5.363203}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753,
-2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829 , -4.427164, 4.50509 , -5.60023 , 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584,0.509112, -0.233068, -0., 0.214528,-0.509112, 0.699204, -0.885433, 1.072641,-1.527335, 1.631475, -1.770866,
1.930753,-2.545559, 2.563747, -2.656298, 2.788865,-3.563783, 3.496019, -3.541731, 3.646978,-4.582006, 4.42829 , -4.427164,
4.50509 ,-5.60023 , 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803 , 6.221315,-7.636677, 7.225105, -7.083463,
7.079428,-8.6549 , 8.157377, -7.968895, 7.93754 ,-9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192 , -9.739761,
9.653765,-11.709571, 10.954192, -10.625194, 10.511877,-12.727795, 11.886464, -11.510627, 11.36999 ,-13.746018, 12.818735, -12.39606 , 12.228102}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0. , -0.254556, 0.466136, 0.699204, 0.932272, 1.16534 , 1.398407, 1.631475, 1.864543, 2.097611,
-2.213582, -2.43494 , -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309 , 3.861506, 4.076034, 4.290562, 4.50509 , 4.719618, 4.934146, 5.148675, 5.363203,
-6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784,
-9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999 , 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto dLdI = results->at(0);
auto dLdG = results->at(3);
auto dLdB = results->at(4);
// dLdI->printBuffer();
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
delete results;
}
/*
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {


@@ -64,7 +64,7 @@ TEST_F(MklDnnTests, helpers_includer) {
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp;
nd4j::ops::platforms::PLATFORM_lrn lrn;
nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm;
nd4j::ops::platforms::PLATFORM_batchnorm batchnorm;
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
#endif


@@ -142,7 +142,7 @@ public class BatchNorm extends DynamicCustomOp {
@Override
public String opName() {
return "batchnorm_new";
return "batchnorm";
}
@Override