Shyrma bnorm bp (#41)

Batchnorm backprop mkldnn
master
Yurii Shyrma 2019-11-12 10:58:48 +02:00 committed by raver119
parent cd961727bb
commit 0eda1e733e
3 changed files with 284 additions and 117 deletions


@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -88,8 +89,27 @@ CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
// auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
// auto m = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
// NDArray stdInv = *v + epsilon;
// stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
// stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
// if(applyScale)
// stdInv *= *gamma;
// // empty array with same shape as input
// input->applyBroadcast(nd4j::broadcast::Subtract, axes, m, output);
// output->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
// if(applyOffset)
// output->applyBroadcast(nd4j::broadcast::Add, axes, beta);
// delete v;
// delete m;
return Status::OK();
}
@@ -113,10 +133,9 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
NDArray* input = INPUT_VARIABLE(0);
NDArray* mean = INPUT_VARIABLE(1);
NDArray* variance = INPUT_VARIABLE(2);
NDArray* dLdO = INPUT_VARIABLE(3); // next epsilon
NDArray* gamma = nullptr;
NDArray* beta = nullptr;
NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // next epsilon
NDArray* dLdI = OUTPUT_VARIABLE(0);
NDArray* dLdM = OUTPUT_VARIABLE(1);
@@ -129,11 +148,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
const float epsilon = T_ARG(0);
if(applyScale) {
gamma = INPUT_VARIABLE(4);
gamma = INPUT_VARIABLE(3);
dLdG = OUTPUT_VARIABLE(3);
}
if(applyOffset) {
beta = INPUT_VARIABLE(4 + (int)applyScale);
beta = INPUT_VARIABLE(3 + (int)applyScale);
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
}
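// e.g. (illustrative, not part of the diff): with applyScale = applyOffset = true the op now receives
// {input(0), mean(1), variance(2), gamma(3), beta(4), dLdO(5)}, i.e. dLdO is always the last input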
@@ -172,67 +191,120 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
// types of all input arrays should be the same (except dLdO)
for(int i = 1; i < block.width() - 1; ++i)
if(i != 3)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
for(int i = 1; i < block.width() - 2; ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
// ***** calculations ***** //
// formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
// notations:
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta), i.e. dLdO multiplied by the feed-forward output
// g = dLdO
// stdInv = 1 / (v + eps)^0.5
// N - number of elements reduced per channel (batch size times the product of spatial dimensions)
// consider mean and variance as constants (since we get them as inputs and don't calculate them)
// dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
// dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
// dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
// dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
// dLdB = dLdO_sum
// derivatives:
// dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
// dfdx = gamma*stdInv*g;
// dfdm = -gamma*stdInv*g_sum;
// dmdx = 1/N;
// dvdx = 2 * (x - m) / N
// dvdm = -2 * [(x - m)]_sum / N
// dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, gamma is factored out here for convenience and applied in the final expression
// finally:
// dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) )
// dLdG = (g * (x - m))_sum * stdInv
// dLdB = g_sum
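// For reference (a restatement of the comments above in LaTeX notation, not new logic), with
// \sigma^{-1} = (v + \epsilon)^{-1/2} (stdInv) and \bar{g} = g_sum / N:
//   \frac{dL}{dx_i} = \gamma\Big[\sigma^{-1}\,(g_i - \bar{g}) + \frac{2}{N}\,\frac{\partial f}{\partial v}\Big(\frac{1}{2}\frac{\partial v}{\partial m} + (x_i - m)\Big)\Big]
//   where \frac{\partial f}{\partial v} = -\frac{1}{2}\,\sigma^{-3}\sum_j g_j(x_j - m) and \frac{\partial v}{\partial m} = -\frac{2}{N}\sum_j (x_j - m)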
// variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
// mean = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
NDArray temp1 = *variance + epsilon;
temp1.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
auto temp2 = temp1.transform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
if(applyScale)
temp2 *= *gamma; // gamma / (variance + epsilon)^0.5
NDArray temp3(input); // empty array with same shape as input
input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3); // input - mean
temp3 *= *dLdO; // (input - mean) * dLdO
const bool keepUnitiesInShape = inRank == mean->rankOf();
// dLdI
dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI);
// inverse batch size 1/N
const float Ninv = 1.f * shape::tadLength(input->getShapeInfo(), axes.data(), axes.size()) / input->lengthOf();
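// e.g. (illustrative): input shape {2,3,4} with axes = {1} gives tadLength = 3 and lengthOf() = 24,
// so Ninv = 3.f / 24.f = 1/8, i.e. N = 8 elements are reduced per channel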
// dLdM
dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); // dLdO sum over excluded axes
// input - mean
NDArray xMinusMean(input); // empty array with same shape as input
input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
// stdInv
NDArray stdInv = *variance + epsilon;
stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
// dvdm (use dLdM as storage for dvdm)
xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);
*dLdM *= -Ninv;
// g_sum
auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
// dLdB
if(applyOffset)
dLdB->assign(dLdM);
dLdB->assign(gSum);
// dLdM
// dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
// dLdM->applyTransform(nd4j::transform::Neg);
*dLdM = 0; // set to zero for now
// stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
gSum *= Ninv;
dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI);
dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
//dLdV
temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); // ((input - mean) * dLdO)_sum
// dLdV <- [g*(x - m)]_sum
(xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
// dLdG
if(applyScale) {
dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG);
// dLdV->assign(dLdG);
dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma);
}
else
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
*dLdV *= stdInv;
if(applyScale)
dLdG->assign(dLdV);
// dLdV
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
// *dLdV *= -0.5;
// (2 / N) * dfdv (use dLdV as storage for dfdv)
*dLdV *= stdInv*stdInv; // dLdV now holds [g*(x - m)]_sum * stdInv^3
*dLdV *= -Ninv; // -0.5f * (2 / N);
// dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM);
xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
// dLdI
*dLdI += xMinusMean;
if(applyScale)
dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma);
*dLdM = 0; // set to zero for now
*dLdV = 0; // set to zero for now
// reference: the equivalent Java implementation
// NDArray std = *variance + epsilon;
// std.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
// std.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
// NDArray xMu(input);
// input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMu);
// NDArray xHat(input);
// xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, &std, &xHat);
// NDArray dxhat(input);
// dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, &dxhat);
// NDArray temp = dxhat*xMu;
// temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
// *dLdV *= -0.5f * std*std*std;
// NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
// *dxmu1 *= -std;
// NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
// *dxmu2 *= *dLdV * (-2.f/N);
// NDArray dLdmu = *dxmu1 + *dxmu2;
// dLdmu *= (1.f /N);
// *dLdV *= (2.f/N);
// dxhat.applyBroadcast(nd4j::broadcast::Multiply, axes, &std);
// xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
// dxhat += xMu;
// dxhat.applyBroadcast(nd4j::broadcast::Add, axes, &dLdmu, dLdI);
// delete dxmu1;
// delete dxmu2;
// xHat *= *dLdO;
// xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape);
return Status::OK();
}
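A minimal call sketch reflecting the reordered inputs (dLdO has moved from fixed index 3 to the last position); names and values are illustrative and mirror the updated tests below:

NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);

nd4j::ops::batchnorm_bp op;
// inputs: {input, mean, variance, [gamma], [beta], dLdO}; T args: {epsilon}; I args: {applyScale, applyOffset, [axes...]}
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});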


@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -55,7 +56,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
// indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats;
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift;
@@ -182,7 +183,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
// indicate whether gamma or/and beta are given
auto flags = mkldnn::normalization_flags::use_global_stats;
auto flags = mkldnn::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
if (weights != nullptr)
flags |= mkldnn::normalization_flags::use_scale_shift;
@@ -308,6 +309,70 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
stream.wait();
// shape::printArray(dLdI_mkl_mem.map_data<float>(),8);
// notations:
// f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
// g = dLdO
// stdInv = 1 / (v + eps)^0.5
// N - number of elements reduced per channel (batch size times the product of spatial dimensions)
// formula for full derivative with respect to input (x)
// dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
// !!! MKL CALCULATES ONLY THE FIRST TERM dfdx, SO WE MUST CALCULATE THE REMAINING TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) OURSELVES !!!
// dfdm = -gamma*stdInv*g_sum;
// dmdx = 1/N;
// dvdx = 2 * (x - m) / N
// dvdm = -2 * [(x - m)]_sum / N
// dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, gamma is factored out here for convenience and applied in the final expression
// finally:
// dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
// dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )
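// Equivalently, the correction added on top of MKL-DNN's dfdx, in LaTeX (same notation as the
// comments above; a restatement, not extra logic):
//   \Delta_i = \gamma\Big[-\sigma^{-1}\,\bar{g} + \frac{2}{N}\,\frac{\partial f}{\partial v}\Big(\frac{1}{2}\frac{\partial v}{\partial m} + (x_i - m)\Big)\Big]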
std::vector<int> axes = {1};
const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);
// inverse batch size 1 / N
const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf();
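// e.g. (illustrative): for NCHW x with shape {2,3,4,4} and mean of length 3, Ninv = 3.f / 96.f = 1/32,
// i.e. N = 2*4*4 = 32 elements per channel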
// x - mean
NDArray xMinusMean(x); // empty array with same shape as x
const_cast<NDArray*>(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
// stdInv
NDArray stdInv = *variance + epsilon;
stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
// dfdm / N
auto dfdm = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes);
dfdm *= stdInv;
dfdm *= -Ninv;
// dvdm / 2
NDArray dvdm(mean); // empty array with same shape as mean
xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, &dvdm, excludedAxes);
dvdm *= -Ninv;
// (2/N)*dfdv
NDArray dfdv(variance); // empty array with same shape as variance
(xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, &dfdv, excludedAxes);
dfdv *= stdInv*stdInv*stdInv;
dfdv *= -Ninv;
// dvdm/2 + (x - m)
xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dvdm);
// dfdv * (dvdm/2 + (x - m))
xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &dfdv);
// add dfdm / N
xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dfdm);
// * gamma
auto gamma = (*weights)({0,1, 0,0});
xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &gamma);
*dLdI += xMinusMean;
}
PLATFORM_IMPL(batchnorm) {
@@ -371,10 +436,21 @@ PLATFORM_IMPL(batchnorm) {
(*weights)({1,2, 0,0}).assign(0);
}
if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
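// {0,3,1,2} permutes NHWC -> NCHW, {0,4,1,2,3} permutes NDHWC -> NCDHW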
input = new NDArray(input->permute(permut));
output = new NDArray(output->permute(permut));
}
batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
delete weights;
if(axes[0] == inRank - 1 && inRank > 2) {
delete input;
delete output;
}
return Status::OK();
}
@@ -418,7 +494,7 @@ PLATFORM_CHECK(batchnorm) {
const int inRank = input->rankOf();
return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
(inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
}
@@ -558,29 +634,29 @@ PLATFORM_CHECK(batchnorm) {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp) {
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
NDArray* mean = INPUT_VARIABLE(1); // [c]
NDArray* variance = INPUT_VARIABLE(2); // [c]
NDArray* dLdO = INPUT_VARIABLE(3); // same as input
NDArray* gamma = nullptr; // [c]
NDArray* beta = nullptr; // [c]
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
NDArray* mean = INPUT_VARIABLE(1); // [c]
NDArray* variance = INPUT_VARIABLE(2); // [c]
NDArray* gamma = nullptr; // [c]
NDArray* beta = nullptr; // [c]
NDArray* dLdO = INPUT_VARIABLE(block.width() - 1); // same as input
NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input
NDArray* dLdM = OUTPUT_VARIABLE(1); // [c]
NDArray* dLdV = OUTPUT_VARIABLE(2); // [c]
NDArray* dLdG = nullptr; // [c]
NDArray* dLdB = nullptr; // [c]
NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input
NDArray* dLdM = OUTPUT_VARIABLE(1); // [c]
NDArray* dLdV = OUTPUT_VARIABLE(2); // [c]
NDArray* dLdG = nullptr; // [c]
NDArray* dLdB = nullptr; // [c]
const bool applyScale = (bool)INT_ARG(0);
const bool applyOffset = (bool)INT_ARG(1);
const float epsilon = T_ARG(0);
if(applyScale) {
gamma = INPUT_VARIABLE(4);
gamma = INPUT_VARIABLE(3);
dLdG = OUTPUT_VARIABLE(3);
}
if(applyOffset) {
beta = INPUT_VARIABLE(4 + (int)applyScale);
beta = INPUT_VARIABLE(3 + (int)applyScale);
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
}
@@ -606,7 +682,7 @@ PLATFORM_IMPL(batchnorm_bp) {
if(beta != nullptr)
REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
// types of all input arrays should be the same (except dLdO)
// types of all input arrays should be the same
for(int i = 1; i < block.width() - 1; ++i)
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !");
@@ -626,11 +702,19 @@ PLATFORM_IMPL(batchnorm_bp) {
(*weights)({1,2, 0,0}).assign(0);
}
*dLdM = 0;
*dLdV = 0;
if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
input = new NDArray(input->permute(permut));
dLdO = new NDArray(dLdO->permute(permut));
dLdI = new NDArray(dLdI->permute(permut));
}
batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
*dLdM = 0;
*dLdV = 0;
if(applyScale || applyOffset) {
if(applyScale)
dLdG->assign((*dLdW)({0,1, 0,0}));
@@ -641,6 +725,12 @@ PLATFORM_IMPL(batchnorm_bp) {
delete dLdW;
}
if(axes[0] == inRank - 1 && inRank > 2) {
delete input;
delete dLdO;
delete dLdI;
}
return Status::OK();
}
@@ -696,7 +786,7 @@ PLATFORM_CHECK(batchnorm_bp) {
const int inRank = input->rankOf();
return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
(inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 &&
dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);


@@ -2901,31 +2901,29 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
delete result;
}
////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556, 0., 0.254556,0.509112, 0.763668, 1.018224, 1.272779,
1.527335, 1.781891, 2.036447, 2.291003,2.545559, 2.800115, 3.054671, 3.309227,3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
mean.assign(1.);
variance.assign(0.5);
variance.assign(0.46666667);
gamma.assign(1.2);
// beta.assign(1.); // has no effect on gradient calculations
beta.assign(1.); // has no effect on gradient calculations
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2945,20 +2943,22 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray mean ('c', {3}, {1.05, 1.1, 1.15});
NDArray variance('c', {3}, {0.5, 0.6, 0.7});
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4});
NDArray beta ('c', {3}, nd4j::DataType::DOUBLE);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {3}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742, 0., 0.251742,0.501992, 0.752989, 1.003985, 1.254981,
1.527335, 1.781891, 2.036447, 2.291003,2.517418, 2.76916 , 3.020902, 3.272644,3.513947, 3.764943, 4.015939, 4.266936});
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388});
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4});
NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
-0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
@@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2989,17 +2989,18 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9});
NDArray beta ('c', {2,1,4}, nd4j::DataType::DOUBLE);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {2,1,4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742, 0., 0.251556,0.509112, 0.755225, 1.003985, 1.25778 ,
1.517885, 1.784991, 2.05947 , 2.341504,2.529808, 2.804986, 3.089205, 3.382173,3.541731, 3.824981, 4.11894 , 4.422841});
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 });
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85});
NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
-0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
// beta.assign(1.); // has no effect on gradient calculations
@@ -3007,7 +3008,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3037,8 +3038,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {1.442483, 0.9502 , 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
@@ -3046,7 +3047,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3076,8 +3077,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779,1.018224, 0.763668,-0.466136, -0.233068,0., 0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169, 1.501697,1.716225, 1.930753,
-2.545559, -2.800115,-3.054671, -3.309227,3.262951, 3.496019,3.729087, 3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618, 4.934146,5.148675, 5.363203}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
-1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
-0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
@@ -3086,7 +3088,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3116,8 +3118,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753,
-2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829 , -4.427164, 4.50509 , -5.60023 , 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
-0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
@@ -3126,7 +3129,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3156,20 +3159,21 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584,0.509112, -0.233068, -0., 0.214528,-0.509112, 0.699204, -0.885433, 1.072641,-1.527335, 1.631475, -1.770866,
1.930753,-2.545559, 2.563747, -2.656298, 2.788865,-3.563783, 3.496019, -3.541731, 3.646978,-4.582006, 4.42829 , -4.427164,
4.50509 ,-5.60023 , 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803 , 6.221315,-7.636677, 7.225105, -7.083463,
7.079428,-8.6549 , 8.157377, -7.968895, 7.93754 ,-9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192 , -9.739761,
9.653765,-11.709571, 10.954192, -10.625194, 10.511877,-12.727795, 11.886464, -11.510627, 11.36999 ,-13.746018, 12.818735, -12.39606 , 12.228102}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
-43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
-15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
-27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
gradO.linspace(-0.9, 0.15);
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3201,10 +3205,11 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0. , -0.254556, 0.466136, 0.699204, 0.932272, 1.16534 , 1.398407, 1.631475, 1.864543, 2.097611,
-2.213582, -2.43494 , -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309 , 3.861506, 4.076034, 4.290562, 4.50509 , 4.719618, 4.934146, 5.148675, 5.363203,
-6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784,
-9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999 , 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32);
NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
-27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
@@ -3213,7 +3218,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
nd4j::ops::batchnorm_bp op;
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());