parent cd961727bb
commit 0eda1e733e

@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -88,8 +89,27 @@ CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
     nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);
 
     // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
+    // auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
+    // auto m = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
+
     helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
 
+    // NDArray stdInv = *v + epsilon;
+    // stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    // stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+    // if(applyScale)
+    //     stdInv *= *gamma;
+
+    // // empty array with same shape as input
+    // input->applyBroadcast(nd4j::broadcast::Subtract, axes, m, output);
+    // output->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
+
+    // if(applyOffset)
+    //     output->applyBroadcast(nd4j::broadcast::Add, axes, beta);
+
+    // delete v;
+    // delete m;
+
     return Status::OK();
 }
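As a reference for the formula comment in this hunk, here is a minimal standalone sketch (plain C++ with illustrative names and a flattened [N, C, S] layout; it is not the NDArray-based helper the op actually calls):

    // A minimal sketch of: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta,
    // normalizing along the channel axis of a tensor flattened as [N, C, S].
    #include <cmath>
    #include <vector>

    void batchnormRef(const std::vector<float>& x, const std::vector<float>& mean,
                      const std::vector<float>& var, const std::vector<float>& gamma,
                      const std::vector<float>& beta, float eps,
                      int N, int C, int S, std::vector<float>& out) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int s = 0; s < S; ++s) {
                    const int i = (n * C + c) * S + s;      // row-major index of element (n, c, s)
                    out[i] = gamma[c] * (x[i] - mean[c]) / std::sqrt(var[c] + eps) + beta[c];
                }
    }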
@@ -113,10 +133,9 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     NDArray* input    = INPUT_VARIABLE(0);
     NDArray* mean     = INPUT_VARIABLE(1);
     NDArray* variance = INPUT_VARIABLE(2);
-    NDArray* dLdO     = INPUT_VARIABLE(3);                 // next epsilon
     NDArray* gamma    = nullptr;
     NDArray* beta     = nullptr;
 
+    NDArray* dLdO = INPUT_VARIABLE(block.width() - 1);     // next epsilon
+
     NDArray* dLdI = OUTPUT_VARIABLE(0);
     NDArray* dLdM = OUTPUT_VARIABLE(1);
@@ -129,11 +148,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     const float epsilon = T_ARG(0);
 
     if(applyScale) {
-        gamma = INPUT_VARIABLE(4);
+        gamma = INPUT_VARIABLE(3);
         dLdG  = OUTPUT_VARIABLE(3);
     }
     if(applyOffset) {
-        beta = INPUT_VARIABLE(4 + (int)applyScale);
+        beta = INPUT_VARIABLE(3 + (int)applyScale);
         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
     }
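The index changes in this hunk follow from moving dLdO to the end of the input list; a hypothetical sketch of the resulting slot arithmetic (names are illustrative, not taken from the op):

    // old layout: input(0), mean(1), variance(2), dLdO(3), then optional gamma/beta
    // new layout: input(0), mean(1), variance(2), then optional gamma/beta, dLdO last
    const int gammaIdx = 3;                       // only meaningful when applyScale
    const int betaIdx  = 3 + (int)applyScale;     // shifts right when gamma is present
    const int dLdOIdx  = block.width() - 1;       // always the last input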
@@ -172,67 +191,120 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
 
     // types of all input arrays should be the same (except dLdO)
-    for(int i = 1; i < block.width() - 1; ++i)
-        if(i != 3)
+    for(int i = 1; i < block.width() - 2; ++i)
         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
 
     // ***** calculations ***** //
 
     // formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
+
+    // notations:
+    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
+    // g = dLdO
+    // stdInv = 1 / (v + eps)^0.5
+    // N - batch size (product of spatial dimensions)
 
-    // consider mean and variance as constants (since we get them as inputs and don't calculate them)
-    // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
-    // dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
-    // dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
-    // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
-    // dLdB = dLdO_sum
+    // derivatives:
+    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
+
+    // dfdx = gamma*stdInv*g;
+    // dfdm = -gamma*stdInv*g_sum;
+    // dmdx = 1/N;
+    // dvdx = 2 * (x - m) / N
+    // dvdm = -2 * [(x - m)]_sum / N
+    // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience
+
+    // finally:
+    // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) )
+
+    // dLdG = (g * (x - m))_sum * stdInv
+    // dLdB = g_sum
 
     // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
     // mean = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
 
     const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
 
-    NDArray temp1 = *variance + epsilon;
-    temp1.applyTransform(transform::Reciprocal);                            // 1 / (variance + epsilon)
-    auto temp2 = temp1.transform(transform::Sqrt);                          // 1 / (variance + epsilon)^0.5
-    if(applyScale)
-        temp2 *= *gamma;                                                    // gamma / (variance + epsilon)^0.5
-
-    NDArray temp3(input);                                                   // empty array with same shape as input
-    input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3);   // input - mean
-    temp3 *= *dLdO;                                                         // (input - mean) * dLdO
-
     const bool keepUnitiesInShape = inRank == mean->rankOf();
 
-    // dLdI
-    dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI);
+    // inverse batch size 1/N
+    const float Ninv = 1.f * shape::tadLength(input->getShapeInfo(), axes.data(), axes.size()) / input->lengthOf();
 
-    // dLdM
-    dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);    // dLdO sum over excluded axes
+    // input - mean
+    NDArray xMinusMean(input);                                              // empty array with same shape as input
+    input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
+
+    // stdInv
+    NDArray stdInv = *variance + epsilon;
+    stdInv.applyTransform(transform::Reciprocal);                           // 1 / (variance + epsilon)
+    stdInv.applyTransform(transform::Sqrt);                                 // 1 / (variance + epsilon)^0.5
+
+    // dvdm (use dLdM as storage for dvdm)
+    xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);
+    *dLdM *= -Ninv;
+
+    // g_sum
+    auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
 
     // dLdB
     if(applyOffset)
-        dLdB->assign(dLdM);
+        dLdB->assign(gSum);
 
-    // dLdM
-    // dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
-    // dLdM->applyTransform(nd4j::transform::Neg);
-    *dLdM = 0;      // put zeros so far
+    // stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
+    gSum *= Ninv;
+    dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI);
+    dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
 
-    //dLdV
-    temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);    // ((input - mean) * dLdO)_sum
+    // dLdV <- [g*(x - m)]_sum
+    (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
 
     // dLdG
-    if(applyScale) {
-        dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG);
-        // dLdV->assign(dLdG);
-        dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma);
-    }
-    else
-        // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
+    *dLdV *= stdInv;
+    if(applyScale)
+        dLdG->assign(dLdV);
 
-    // dLdV
-    // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
-    // *dLdV *= -0.5;
+    // (2 / N) * dfdv (use dLdV as storage for dfdv)
+    *dLdV *= stdInv*stdInv;     // dLdV*stdInv * stdInv^2
+    *dLdV *= -Ninv;             // -0.5f * (2 / N);
+
+    // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM);
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
+
+    // dLdI
+    *dLdI += xMinusMean;
+    if(applyScale)
+        dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma);
+
+    *dLdM = 0;      // put zeros so far
+    *dLdV = 0;      // put zeros so far
+
+    // java code
+    // NDArray std = *variance + epsilon;
+    // std.applyTransform(transform::Reciprocal);                           // 1 / (variance + epsilon)
+    // std.applyTransform(transform::Sqrt);                                 // 1 / (variance + epsilon)^0.5
+    // NDArray xMu(input);
+    // input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMu);
+    // NDArray xHat(input);
+    // xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, &std, &xHat);
+    // NDArray dxhat(input);
+    // dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, &dxhat);
+    // NDArray temp = dxhat*xMu;
+    // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
+    // *dLdV *= -0.5f * std*std*std;
+    // NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
+    // *dxmu1 *= -std;
+    // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
+    // *dxmu2 *= *dLdV * (-2.f/N);
+    // NDArray dLdmu = *dxmu1 + *dxmu2;
+    // dLdmu *= (1.f /N);
+    // *dLdV *= (2.f/N);
+    // dxhat.applyBroadcast(nd4j::broadcast::Multiply, axes, &std);
+    // xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
+    // dxhat += xMu;
+    // dxhat.applyBroadcast(nd4j::broadcast::Add, axes, &dLdmu, dLdI);
+    // delete dxmu1;
+    // delete dxmu2;
+    // xHat *= *dLdO;
+    // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape);
 
     return Status::OK();
 }
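The derivation in the comments above can be sanity-checked numerically; below is a hedged, self-contained finite-difference sketch (illustrative names, one channel of N values, mean and variance recomputed from x, which is the full-derivative treatment the new code implements). The dvdm term vanishes exactly when the mean is the true batch mean, so it contributes nothing here but is kept for fidelity to the formula.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // L = sum_i g_i * (gamma * (x_i - m) / sqrt(v + eps) + beta), with m, v recomputed from x
    static double loss(const std::vector<double>& x, const std::vector<double>& g,
                       double gamma, double beta, double eps) {
        const int N = (int)x.size();
        double m = 0, v = 0, L = 0;
        for (double xi : x) m += xi;
        m /= N;
        for (double xi : x) v += (xi - m) * (xi - m);
        v /= N;                                  // biased variance, matching the derivation
        for (int i = 0; i < N; ++i)
            L += g[i] * (gamma * (x[i] - m) / std::sqrt(v + eps) + beta);
        return L;
    }

    int main() {
        const std::vector<double> x = {0.1, 0.5, 0.9, 1.3}, g = {-0.9, 0.3, 0.6, 1.2};
        const double gamma = 1.2, beta = 1.0, eps = 1e-5;
        const int N = (int)x.size();

        double m = 0, v = 0, gSum = 0, gxSum = 0, dvdmHalf = 0;
        for (double xi : x) m += xi;
        m /= N;
        for (double xi : x) v += (xi - m) * (xi - m);
        v /= N;
        const double stdInv = 1.0 / std::sqrt(v + eps);
        for (int i = 0; i < N; ++i) { gSum += g[i]; gxSum += g[i] * (x[i] - m); }
        const double dfdv = -0.5 * gxSum * stdInv * stdInv * stdInv;  // gamma dropped, re-applied below
        for (double xi : x) dvdmHalf += (xi - m);
        dvdmHalf *= -1.0 / N;                    // exactly 0 when m is the batch mean

        for (int i = 0; i < N; ++i) {
            // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) )
            const double analytic = gamma * (stdInv * (g[i] - gSum / N)
                                  + (2.0 / N) * dfdv * (dvdmHalf + (x[i] - m)));
            const double h = 1e-6;
            std::vector<double> xp = x, xm = x;
            xp[i] += h; xm[i] -= h;
            const double numeric = (loss(xp, g, gamma, beta, eps) - loss(xm, g, gamma, beta, eps)) / (2 * h);
            std::printf("i=%d  analytic=%.9f  numeric=%.9f\n", i, analytic, numeric);
        }
        return 0;
    }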
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -55,7 +56,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
 
     // indicate whether gamma or/and beta are given
-    auto flags = mkldnn::normalization_flags::use_global_stats;
+    auto flags = mkldnn::normalization_flags::use_global_stats;    // don't calculate the mean and variance for each mini-batch
     if (weights != nullptr)
         flags |= mkldnn::normalization_flags::use_scale_shift;
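For context, a rough sketch of how such flags would feed a forward batch-normalization primitive, assuming the mkldnn/DNNL 1.x C++ API (dims, epsilon, and names below are illustrative, not taken from this file):

    #include <mkldnn.hpp>

    void buildBnormForward(bool haveWeights) {
        mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
        mkldnn::memory::desc dataMd({2, 4, 8, 8}, mkldnn::memory::data_type::f32,
                                    mkldnn::memory::format_tag::nchw);
        auto flags = mkldnn::normalization_flags::use_global_stats;   // stats come in as inputs
        if (haveWeights)
            flags |= mkldnn::normalization_flags::use_scale_shift;    // gamma/beta packed into "weights"
        const float epsilon = 1e-5f;
        mkldnn::batch_normalization_forward::desc d(mkldnn::prop_kind::forward_inference,
                                                    dataMd, epsilon, flags);
        mkldnn::batch_normalization_forward::primitive_desc pd(d, eng);
        (void)pd;   // descriptor only; memory binding and execution omitted
    }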
@@ -182,7 +183,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
 
     // indicate whether gamma or/and beta are given
-    auto flags = mkldnn::normalization_flags::use_global_stats;
+    auto flags = mkldnn::normalization_flags::use_global_stats;    // don't calculate the mean and variance for each mini-batch
     if (weights != nullptr)
         flags |= mkldnn::normalization_flags::use_scale_shift;
@@ -308,6 +309,70 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     stream.wait();
 
+    // shape::printArray(dLdI_mkl_mem.map_data<float>(),8);
+
+    // notations:
+    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
+    // g = dLdO
+    // stdInv = 1 / (v + eps)^0.5
+    // N - batch size (product of spatial dimensions)
+
+    // formula for full derivative with respect to input (x)
+    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
+
+    // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELF !!!
+
+    // dfdm = -gamma*stdInv*g_sum;
+    // dmdx = 1/N;
+    // dvdx = 2 * (x - m) / N
+    // dvdm = -2 * [(x - m)]_sum / N
+    // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience
+
+    // finally:
+    // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
+    // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )
+
+    std::vector<int> axes = {1};
+    const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);
+
+    // inversed batch size 1 / N
+    const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf();
+
+    // x - mean
+    NDArray xMinusMean(x);      // empty array with same shape as x
+    const_cast<NDArray*>(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
+
+    // stdInv
+    NDArray stdInv = *variance + epsilon;
+    stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+
+    // dfdm / N
+    auto dfdm = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes);
+    dfdm *= stdInv;
+    dfdm *= -Ninv;
+
+    // dvdm / 2
+    NDArray dvdm(mean);         // empty array with same shape as mean
+    xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, &dvdm, excludedAxes);
+    dvdm *= -Ninv;
+
+    // (2/N)*dfdv
+    NDArray dfdv(variance);     // empty array with same shape as variance
+    (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, &dfdv, excludedAxes);
+    dfdv *= stdInv*stdInv*stdInv;
+    dfdv *= -Ninv;
+
+    // dvdm/2 + (x - m)
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dvdm);
+    // dfdv * (dvdm/2 + (x - m))
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &dfdv);
+    // add dfdm / N
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dfdm);
+    // * gamma
+    auto gamma = (*weights)({0,1, 0,0});
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &gamma);
+
+    *dLdI += xMinusMean;
 }
 
 PLATFORM_IMPL(batchnorm) {
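The comment block in this hunk can be restated in display math (a transcription of the notation above, with sums running over the N values that share a channel):

    % transcription of the comments above, not an independent derivation
    \frac{\partial L}{\partial x}
      = \underbrace{\gamma\,\sigma^{-1} g}_{\text{dfdx: the only term MKL-DNN returns under use\_global\_stats}}
      + \underbrace{\frac{\partial f}{\partial m}\,\frac{1}{N}
      + \frac{\partial f}{\partial v}\left(\frac{\partial v}{\partial m}\,\frac{1}{N} + \frac{2(x-m)}{N}\right)}_{\text{correction accumulated into dLdI above}}
    \qquad\text{where}\quad
    \sigma^{-1} = (v+\epsilon)^{-1/2},\quad
    \frac{\partial f}{\partial m} = -\gamma\,\sigma^{-1}\sum g,\quad
    \frac{\partial v}{\partial m} = -\frac{2}{N}\sum (x-m),\quad
    \frac{\partial f}{\partial v} = -\frac{1}{2}\,\sigma^{-3}\sum g\,(x-m).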
@@ -371,10 +436,21 @@ PLATFORM_IMPL(batchnorm) {
         (*weights)({1,2, 0,0}).assign(0);
     }
 
+    if(axes[0] == inRank - 1 && inRank > 2) {       // if nhwc or ndhwc
+        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+        input  = new NDArray(input->permute(permut));
+        output = new NDArray(output->permute(permut));
+    }
+
     batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
 
     delete weights;
 
+    if(axes[0] == inRank - 1 && inRank > 2) {
+        delete input;
+        delete output;
+    }
+
     return Status::OK();
 }
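A tiny standalone check of the permutation used in this hunk: for a rank-4 NHWC tensor, permuting the shape with {0,3,1,2} yields NCHW (dims below are illustrative):

    #include <cstdio>

    int main() {
        const int shapeNHWC[4] = {2, 5, 7, 3};    // hypothetical N, H, W, C
        const int permut[4]    = {0, 3, 1, 2};
        int shapeNCHW[4];
        for (int i = 0; i < 4; ++i)
            shapeNCHW[i] = shapeNHWC[permut[i]];  // axis i of the result is axis permut[i] of the source
        std::printf("%d %d %d %d\n", shapeNCHW[0], shapeNCHW[1], shapeNCHW[2], shapeNCHW[3]);  // 2 3 5 7
        return 0;
    }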
@@ -418,7 +494,7 @@ PLATFORM_CHECK(batchnorm) {
 
     const int inRank = input->rankOf();
 
-    return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
+    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
            (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
             gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
 }
@@ -561,9 +637,9 @@ PLATFORM_IMPL(batchnorm_bp) {
     NDArray* input    = INPUT_VARIABLE(0);                 // 2D:nc, 4D:nchw, 5D:ncdhw
     NDArray* mean     = INPUT_VARIABLE(1);                 // [c]
     NDArray* variance = INPUT_VARIABLE(2);                 // [c]
-    NDArray* dLdO     = INPUT_VARIABLE(3);                 // same as input
     NDArray* gamma    = nullptr;                           // [c]
     NDArray* beta     = nullptr;                           // [c]
+    NDArray* dLdO = INPUT_VARIABLE(block.width() - 1);     // same as input
 
     NDArray* dLdI = OUTPUT_VARIABLE(0);                    // same as input
     NDArray* dLdM = OUTPUT_VARIABLE(1);                    // [c]
@@ -576,11 +652,11 @@ PLATFORM_IMPL(batchnorm_bp) {
     const float epsilon = T_ARG(0);
 
     if(applyScale) {
-        gamma = INPUT_VARIABLE(4);
+        gamma = INPUT_VARIABLE(3);
         dLdG  = OUTPUT_VARIABLE(3);
     }
     if(applyOffset) {
-        beta = INPUT_VARIABLE(4 + (int)applyScale);
+        beta = INPUT_VARIABLE(3 + (int)applyScale);
         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
     }
@@ -606,7 +682,7 @@ PLATFORM_IMPL(batchnorm_bp) {
     if(beta != nullptr)
         REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
 
-    // types of all input arrays should be the same (except dLdO)
+    // types of all input arrays should be the same
     for(int i = 1; i < block.width() - 1; ++i)
         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !");
@@ -626,11 +702,19 @@ PLATFORM_IMPL(batchnorm_bp) {
         (*weights)({1,2, 0,0}).assign(0);
     }
 
-    *dLdM = 0;
-    *dLdV = 0;
-
+    if(axes[0] == inRank - 1 && inRank > 2) {       // if nhwc or ndhwc
+        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+        input = new NDArray(input->permute(permut));
+        dLdO  = new NDArray(dLdO->permute(permut));
+        dLdI  = new NDArray(dLdI->permute(permut));
+    }
+
     batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
 
+    *dLdM = 0;
+    *dLdV = 0;
+
     if(applyScale || applyOffset) {
         if(applyScale)
             dLdG->assign((*dLdW)({0,1, 0,0}));
@@ -641,6 +725,12 @@ PLATFORM_IMPL(batchnorm_bp) {
         delete dLdW;
     }
 
+    if(axes[0] == inRank - 1 && inRank > 2) {
+        delete input;
+        delete dLdO;
+        delete dLdI;
+    }
+
     return Status::OK();
 }
@@ -696,7 +786,7 @@ PLATFORM_CHECK(batchnorm_bp) {
 
     const int inRank = input->rankOf();
 
-    return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
+    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
            (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
             dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 &&
             dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);
@@ -2901,31 +2901,29 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
     delete result;
 }
 
-////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
 
     NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
-    NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
     NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gamma   ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668, -0.509112, -0.254556, 0., 0.254556, 0.509112, 0.763668, 1.018224, 1.272779,
-                                    1.527335, 1.781891, 2.036447, 2.291003, 2.545559, 2.800115, 3.054671, 3.309227, 3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32);
-    NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 0.000056}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
-    mean.assign(1.);
-    variance.assign(0.5);
+    variance.assign(0.46666667);
     gamma.assign(1.2);
-    // beta.assign(1.);    // has no effect on gradient calculations
+    beta.assign(1.);       // has no effect on gradient calculations
    gradO.linspace(-0.9, 0.15);
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
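The expected dLdB values in this test follow directly from dLdB = g_sum in the derivation above; a standalone arithmetic check (illustrative, not part of the test suite):

    #include <cstdio>

    // dLdB should equal the per-channel sum of gradO. With gradO.linspace(-0.9, 0.15)
    // over shape {2,3,4} and the channel as the last axis, the sums reproduce
    // expdLdB = {3.6, 4.5, 5.4, 6.3}.
    int main() {
        float sums[4] = {0, 0, 0, 0};
        for (int i = 0; i < 24; ++i)
            sums[i % 4] += -0.9f + 0.15f * i;           // element i of gradO; i % 4 is its channel
        for (int c = 0; c < 4; ++c)
            std::printf("dLdB[%d] = %.1f\n", c, sums[c]);   // prints 3.6 4.5 5.4 6.3
        return 0;
    }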
@@ -2945,20 +2943,22 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
     delete results;
 }
 
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
 
-    NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
-    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15});
-    NDArray variance('c', {3}, {0.5, 0.6, 0.7});
-    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4});
-    NDArray beta    ('c', {3}, nd4j::DataType::DOUBLE);
-    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
+    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
+    NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
+    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
+    NDArray beta    ('c', {3}, nd4j::DataType::FLOAT32);
+    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668, -0.503484, -0.251742, 0., 0.251742, 0.501992, 0.752989, 1.003985, 1.254981,
-                                    1.527335, 1.781891, 2.036447, 2.291003, 2.517418, 2.76916, 3.020902, 3.272644, 3.513947, 3.764943, 4.015939, 4.266936});
-    NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388});
-    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4});
+    NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747,
+                                    0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978,
+                                   -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
+    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     // beta.assign(1.);    // has no effect on gradient calculations
@@ -2966,7 +2966,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -2989,17 +2989,18 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
 
-    NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
-    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
-    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
-    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9});
-    NDArray beta    ('c', {2,1,4}, nd4j::DataType::DOUBLE);
-    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
+    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
+    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
+    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
+    NDArray beta    ('c', {2,1,4}, nd4j::DataType::FLOAT32);
+    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668, -0.509112, -0.251742, 0., 0.251556, 0.509112, 0.755225, 1.003985, 1.25778,
-                                    1.517885, 1.784991, 2.05947, 2.341504, 2.529808, 2.804986, 3.089205, 3.382173, 3.541731, 3.824981, 4.11894, 4.422841});
-    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234});
-    NDArray expdLdB('c', {2,1,4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85});
+    NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002,
+                                    0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
+                                   -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234}, nd4j::DataType::FLOAT32);
+    NDArray expdLdB('c', {2,1,4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     // beta.assign(1.);    // has no effect on gradient calculations
@@ -3007,7 +3008,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3037,8 +3038,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32);
-    NDArray expdLdG('c', {4}, {1.442483, 0.9502, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
@@ -3046,7 +3047,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3076,8 +3077,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, -0.466136, -0.233068, 0., 0.233068, -0.442716, -0.664075, -0.885433, -1.106791, 1.287169, 1.501697, 1.716225, 1.930753,
-                                    -2.545559, -2.800115, -3.054671, -3.309227, 3.262951, 3.496019, 3.729087, 3.962155, -3.984448, -4.205806, -4.427164, -4.648522, 4.719618, 4.934146, 5.148675, 5.363203}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243,
+                                     -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118,
+                                     -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, nd4j::DataType::FLOAT32);
 
@@ -3086,7 +3088,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3116,8 +3118,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753,
-                                    -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829, -4.427164, 4.50509, -5.60023, 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295,
+                                      0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295,
+                                     -0.339330, 3.563660, -1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);
 
@@ -3126,7 +3129,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3156,20 +3159,21 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866,
-                                       1.930753, -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829, -4.427164,
-                                       4.50509, -5.60023, 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803, 6.221315, -7.636677, 7.225105, -7.083463,
-                                       7.079428, -8.6549, 8.157377, -7.968895, 7.93754, -9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192, -9.739761,
-                                       9.653765, -11.709571, 10.954192, -10.625194, 10.511877, -12.727795, 11.886464, -11.510627, 11.36999, -13.746018, 12.818735, -12.39606, 12.228102}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
+                                       -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
+                                        15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
+                                       -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
+                                       -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {282.38734, 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {57.6, 60., 62.4, 64.8}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     gradO.linspace(-0.9, 0.15);
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -3201,10 +3205,11 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0., -0.254556, 0.466136, 0.699204, 0.932272, 1.16534, 1.398407, 1.631475, 1.864543, 2.097611,
-                                      -2.213582, -2.43494, -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309, 3.861506, 4.076034, 4.290562, 4.50509, 4.719618, 4.934146, 5.148675, 5.363203,
-                                      -6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784,
-                                      -9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999, 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
+                                        32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
+                                       -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
+                                        30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
+                                        31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
 
@@ -3213,7 +3218,7 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());