diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp
index 5641bab43..8b6bd24bc 100644
--- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -88,8 +89,27 @@ CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
     nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);
 
     // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
+    // auto v = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
+    // auto m = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
+
     helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
 
+    // NDArray stdInv = *v + epsilon;
+    // stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    // stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+    // if(applyScale)
+    //     stdInv *= *gamma;
+
+    // // empty array with same shape as input
+    // input->applyBroadcast(nd4j::broadcast::Subtract, axes, m, output);
+    // output->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
+
+    // if(applyOffset)
+    //     output->applyBroadcast(nd4j::broadcast::Add, axes, beta);
+
+    // delete v;
+    // delete m;
+
     return Status::OK();
 }
 
@@ -113,10 +133,9 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     NDArray* input    = INPUT_VARIABLE(0);
     NDArray* mean     = INPUT_VARIABLE(1);
     NDArray* variance = INPUT_VARIABLE(2);
-    NDArray* dLdO     = INPUT_VARIABLE(3);                   // next epsilon
     NDArray* gamma    = nullptr;
     NDArray* beta     = nullptr;
-
+    NDArray* dLdO     = INPUT_VARIABLE(block.width() - 1);   // next epsilon
 
     NDArray* dLdI = OUTPUT_VARIABLE(0);
     NDArray* dLdM = OUTPUT_VARIABLE(1);
@@ -129,11 +148,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     const float epsilon = T_ARG(0);
 
     if(applyScale) {
-        gamma = INPUT_VARIABLE(4);
+        gamma = INPUT_VARIABLE(3);
         dLdG  = OUTPUT_VARIABLE(3);
     }
     if(applyOffset) {
-        beta = INPUT_VARIABLE(4 + (int)applyScale);
+        beta = INPUT_VARIABLE(3 + (int)applyScale);
         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
     }
 
@@ -172,67 +191,120 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
     REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
 
     // types of all input arrays should be the same (except dLdO)
-    for(int i = 1; i < block.width() - 1; ++i)
-        if(i != 3)
-            REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
+    for(int i = 1; i < block.width() - 2; ++i)
+        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
 
     // ***** calculations ***** //
 
-    // formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
+    // notations:
+    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta)   -> means dLdO * ff_output
+    // g = dLdO
+    // stdInv = 1 / (v + eps)^0.5
+    // N - batch size (product of spatial dimensions)
 
-    // consider mean and variance as constants (since we get them as inputs and don't calculate them)
-    // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
-    // dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
-    // dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
-    // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
-    // dLdB = dLdO_sum
+    // derivatives:
+    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
+
+    // dfdx = gamma*stdInv*g;
+    // dfdm = -gamma*stdInv*g_sum;
+    // dmdx = 1/N;
+    // dvdx = 2 * (x - m) / N
+    // dvdm = -2 * [(x - m)]_sum / N
+    // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience
+
+    // finally:
+    // dLdI = gamma * ( stdInv * (g - g_sum/N) + (2/N) * dfdv * (dvdm/2 + (x - m)) )
+
+    // dLdG = (g * (x - m))_sum * stdInv
+    // dLdB = g_sum
+
+    // variance = input->varianceAlongDimension(variance::SummaryStatsVariance, false, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
+    // mean = input->reduceAlongDimension(nd4j::reduce::Mean, ShapeUtils::evalDimsToExclude(input->rankOf(), axes));
 
     const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
-
-    NDArray temp1 = *variance + epsilon;
-    temp1.applyTransform(transform::Reciprocal);                            // 1 / (variance + epsilon)
-    auto temp2 = temp1.transform(transform::Sqrt);                          // 1 / (variance + epsilon)^0.5
-    if(applyScale)
-        temp2 *= *gamma;                                                    // gamma / (variance + epsilon)^0.5
-
-    NDArray temp3(input);                                                   // empty array with same shape as input
-    input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3);   // input - mean
-    temp3 *= *dLdO;                                                         // (input - mean) * dLdO
-
     const bool keepUnitiesInShape = inRank == mean->rankOf();
 
-    // dLdI
-    dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI);
+    // inverse batch size 1/N
+    const float Ninv = 1.f * shape::tadLength(input->getShapeInfo(), axes.data(), axes.size()) / input->lengthOf();
 
-    // dLdM
-    dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);    // dLdO sum over excluded axes
+    // input - mean
+    NDArray xMinusMean(input);      // empty array with same shape as input
+    input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
+
+    // stdInv
+    NDArray stdInv = *variance + epsilon;
+    stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+
+    // dvdm (use dLdM as storage for dvdm)
+    xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);
+    *dLdM *= -Ninv;
+
+    // g_sum
+    auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
 
     // dLdB
     if(applyOffset)
-        dLdB->assign(dLdM);
+        dLdB->assign(gSum);
 
-    // dLdM
-    // dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
-    // dLdM->applyTransform(nd4j::transform::Neg);
-    *dLdM = 0;      // put zeros so far
+    // stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
+    gSum *= Ninv;
+    dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI);
+    dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
 
-    //dLdV
-    temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);    // ((input - mean) * dLdO)_sum
+    // dLdV <- [g*(x - m)]_sum
+    (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
 
     // dLdG
-    if(applyScale) {
-        dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG);
-        // dLdV->assign(dLdG);
-        dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma);
-    }
-    else
-        // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
+    *dLdV *= stdInv;
+    if(applyScale)
+        dLdG->assign(dLdV);
 
-    // dLdV
-    // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
-    // *dLdV *= -0.5;
+    // (2 / N) * dfdv (use dLdV as storage for dfdv)
+    *dLdV *= stdInv*stdInv;     // dLdV*stdInv * stdInv^2
+    *dLdV *= -Ninv;             // -0.5f * (2 / N);
+
+    // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM);
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
+
+    // dLdI
+    *dLdI += xMinusMean;
+    if(applyScale)
+        dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma);
+
+    *dLdM = 0;      // put zeros so far
     *dLdV = 0;      // put zeros so far
 
+    // java code
+    // NDArray std = *variance + epsilon;
+    // std.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    // std.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+    // NDArray xMu(input);
+    // input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMu);
+    // NDArray xHat(input);
+    // xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, &std, &xHat);
+    // NDArray dxhat(input);
+    // dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, &dxhat);
+    // NDArray temp = dxhat*xMu;
+    // temp.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
+    // *dLdV *= -0.5f * std*std*std;
+    // NDArray* dxmu1 = dxhat.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
+    // *dxmu1 *= -std;
+    // NDArray* dxmu2 = xMu.reduceAlongDimension(reduce::Sum, excludedAxes, keepUnitiesInShape);
+    // *dxmu2 *= *dLdV * (-2.f/N);
+    // NDArray dLdmu = *dxmu1 + *dxmu2;
+    // dLdmu *= (1.f /N);
+    // *dLdV *= (2.f/N);
+    // dxhat.applyBroadcast(nd4j::broadcast::Multiply, axes, &std);
+    // xMu.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
+    // dxhat += xMu;
+    // dxhat.applyBroadcast(nd4j::broadcast::Add, axes, &dLdmu, dLdI);
+    // delete dxmu1;
+    // delete dxmu2;
+    // xHat *= *dLdO;
+    // xHat.reduceAlongDimension(reduce::Sum, dLdG, excludedAxes, keepUnitiesInShape);
+
     return Status::OK();
 }
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
index 13e1cfe11..27f836a0e 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -55,7 +56,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
 
     // indicate whether gamma or/and beta are given
-    auto flags = mkldnn::normalization_flags::use_global_stats;
+    auto flags = mkldnn::normalization_flags::use_global_stats;     // don't calculate the mean and variance for each mini-batch
     if (weights != nullptr)
         flags |= mkldnn::normalization_flags::use_scale_shift;
 
@@ -182,7 +183,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
 
     // indicate whether gamma or/and beta are given
-    auto flags = mkldnn::normalization_flags::use_global_stats;
+    auto flags = mkldnn::normalization_flags::use_global_stats;     // don't calculate the mean and variance for each mini-batch
     if (weights != nullptr)
         flags |= mkldnn::normalization_flags::use_scale_shift;
 
@@ -308,6 +309,70 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
     stream.wait();
 
     // shape::printArray(dLdI_mkl_mem.map_data(),8);
+
+    // notations:
+    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta)   -> means dLdO * ff_output
+    // g = dLdO
+    // stdInv = 1 / (v + eps)^0.5
+    // N - batch size (product of spatial dimensions)
+
+    // formula for full derivative with respect to input (x)
+    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)
+
+    // !!! MKL CALCULATES ONLY FIRST TERM dfdx, SO WE SHOULD CALCULATE TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) BY OURSELVES !!!
+
+    // dfdm = -gamma*stdInv*g_sum;
+    // dmdx = 1/N;
+    // dvdx = 2 * (x - m) / N
+    // dvdm = -2 * [(x - m)]_sum / N
+    // dfdv = -0.5 * [g*(x - m)]_sum * stdInv^3, drop gamma here for calc convenience
+
+    // finally:
+    // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
+    // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )
+
+    std::vector<int> axes = {1};
+    const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);
+
+    // inverse batch size 1 / N
+    const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf();
+
+    // x - mean
+    NDArray xMinusMean(x);      // empty array with same shape as x
+    const_cast<NDArray*>(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
+
+    // stdInv
+    NDArray stdInv = *variance + epsilon;
+    stdInv.applyTransform(transform::Reciprocal);    // 1 / (variance + epsilon)
+    stdInv.applyTransform(transform::Sqrt);          // 1 / (variance + epsilon)^0.5
+
+    // dfdm / N
+    auto dfdm = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes);
+    dfdm *= stdInv;
+    dfdm *= -Ninv;
+
+    // dvdm / 2
+    NDArray dvdm(mean);         // empty array with same shape as mean
+    xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, &dvdm, excludedAxes);
+    dvdm *= -Ninv;
+
+    // (2/N)*dfdv
+    NDArray dfdv(variance);     // empty array with same shape as variance
+    (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, &dfdv, excludedAxes);
+    dfdv *= stdInv*stdInv*stdInv;
+    dfdv *= -Ninv;
+
+    // dvdm/2 + (x - m)
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dvdm);
+    // dfdv * (dvdm/2 + (x - m))
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &dfdv);
+    // add dfdm / N
+    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, &dfdm);
+    // * gamma
+    auto gamma = (*weights)({0,1, 0,0});
+    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, &gamma);
+
+    *dLdI += xMinusMean;
 }
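The comment blocks above, in both the generic op and the MKL-DNN wrapper, carry the same derivation, so it may help to see the "finally" formula written out end to end. Below is a minimal single-channel sketch in plain C++, independent of the NDArray API; the function name batchnormBpReference and all variable names are hypothetical, and biased variance v = [(x - m)^2]_sum / N is assumed. It illustrates the formulas and is not code from the patch.

#include <cmath>
#include <numeric>
#include <vector>

// Reference batchnorm backprop for one channel over a batch of size N.
// Implements: dLdI = gamma * ( stdInv*(g - g_sum/N) + (2/N)*dfdv*(dvdm/2 + (x - m)) )
//             dLdG = [g*(x - m)]_sum * stdInv,   dLdB = g_sum
std::vector<float> batchnormBpReference(const std::vector<float>& x,
                                        const std::vector<float>& g,    // upstream gradient dLdO
                                        float m, float v, float gamma,
                                        float eps, float& dLdG, float& dLdB) {
    const int   N      = static_cast<int>(x.size());
    const float stdInv = 1.0f / std::sqrt(v + eps);                 // 1 / (v + eps)^0.5
    const float gSum   = std::accumulate(g.begin(), g.end(), 0.0f); // g_sum

    float gxmuSum = 0.0f, xmuSum = 0.0f;
    for (int i = 0; i < N; ++i) {
        gxmuSum += g[i] * (x[i] - m);                               // [g*(x - m)]_sum
        xmuSum  += x[i] - m;                                        // [(x - m)]_sum
    }

    dLdB = gSum;
    dLdG = gxmuSum * stdInv;

    const float dvdm = -2.0f * xmuSum / N;                          // dvdm = -2*[(x - m)]_sum / N
    const float dfdv = -0.5f * gxmuSum * stdInv * stdInv * stdInv;  // gamma dropped, as in the patch

    std::vector<float> dLdI(N);
    for (int i = 0; i < N; ++i)
        dLdI[i] = gamma * (stdInv * (g[i] - gSum / N)
                + (2.0f / N) * dfdv * (dvdm / 2.0f + (x[i] - m)));
    return dLdI;
}

When m and v are the true batch statistics, [(x - m)]_sum vanishes and the dvdm term drops out, leaving the familiar stdInv*(g - g_sum/N) and dfdv pieces; the op keeps the dvdm term because mean and variance arrive as op inputs and are not required to match the current batch.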
 PLATFORM_IMPL(batchnorm) {
 
@@ -371,10 +436,21 @@ PLATFORM_IMPL(batchnorm) {
         (*weights)({1,2, 0,0}).assign(0);
     }
 
+    if(axes[0] == inRank - 1 && inRank > 2) {   // if nhwc or ndhwc
+        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+        input  = new NDArray(input->permute(permut));
+        output = new NDArray(output->permute(permut));
+    }
+
     batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
 
     delete weights;
 
+    if(axes[0] == inRank - 1 && inRank > 2) {
+        delete input;
+        delete output;
+    }
+
     return Status::OK();
 }
 
@@ -418,7 +494,7 @@ PLATFORM_CHECK(batchnorm) {
     const int inRank = input->rankOf();
 
-    return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
+    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
            (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
 }
 
@@ -558,29 +634,29 @@
 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(batchnorm_bp) {
 
-    NDArray* input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw, 5D:ncdhw
-    NDArray* mean     = INPUT_VARIABLE(1);   // [c]
-    NDArray* variance = INPUT_VARIABLE(2);   // [c]
-    NDArray* dLdO     = INPUT_VARIABLE(3);   // same as input
-    NDArray* gamma    = nullptr;             // [c]
-    NDArray* beta     = nullptr;             // [c]
+    NDArray* input    = INPUT_VARIABLE(0);                  // 2D:nc, 4D:nchw, 5D:ncdhw
+    NDArray* mean     = INPUT_VARIABLE(1);                  // [c]
+    NDArray* variance = INPUT_VARIABLE(2);                  // [c]
+    NDArray* gamma    = nullptr;                            // [c]
+    NDArray* beta     = nullptr;                            // [c]
+    NDArray* dLdO     = INPUT_VARIABLE(block.width() - 1);  // same as input
 
-    NDArray* dLdI = OUTPUT_VARIABLE(0);      // same as input
-    NDArray* dLdM = OUTPUT_VARIABLE(1);      // [c]
-    NDArray* dLdV = OUTPUT_VARIABLE(2);      // [c]
-    NDArray* dLdG = nullptr;                 // [c]
-    NDArray* dLdB = nullptr;                 // [c]
+    NDArray* dLdI = OUTPUT_VARIABLE(0);                     // same as input
+    NDArray* dLdM = OUTPUT_VARIABLE(1);                     // [c]
+    NDArray* dLdV = OUTPUT_VARIABLE(2);                     // [c]
+    NDArray* dLdG = nullptr;                                // [c]
+    NDArray* dLdB = nullptr;                                // [c]
 
     const bool  applyScale  = (bool)INT_ARG(0);
     const bool  applyOffset = (bool)INT_ARG(1);
     const float epsilon     = T_ARG(0);
 
     if(applyScale) {
-        gamma = INPUT_VARIABLE(4);
+        gamma = INPUT_VARIABLE(3);
         dLdG  = OUTPUT_VARIABLE(3);
     }
     if(applyOffset) {
-        beta = INPUT_VARIABLE(4 + (int)applyScale);
+        beta = INPUT_VARIABLE(3 + (int)applyScale);
         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
     }
 
@@ -606,7 +682,7 @@ PLATFORM_IMPL(batchnorm_bp) {
     if(beta != nullptr)
         REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
 
-    // types of all input arrays should be the same (except dLdO)
+    // types of all input arrays should be the same
     for(int i = 1; i < block.width() - 1; ++i)
         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !");
 
@@ -626,11 +702,19 @@ PLATFORM_IMPL(batchnorm_bp) {
         (*weights)({1,2, 0,0}).assign(0);
     }
 
-    *dLdM = 0;
-    *dLdV = 0;
+
+    if(axes[0] == inRank - 1 && inRank > 2) {   // if nhwc or ndhwc
+        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
+        input = new NDArray(input->permute(permut));
+        dLdO  = new NDArray(dLdO->permute(permut));
+        dLdI  = new NDArray(dLdI->permute(permut));
+    }
 
     batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
 
+    *dLdM = 0;
+    *dLdV = 0;
+
     if(applyScale || applyOffset) {
         if(applyScale)
             dLdG->assign((*dLdW)({0,1, 0,0}));
@@ -641,6 +725,12 @@ PLATFORM_IMPL(batchnorm_bp) {
         delete dLdW;
     }
 
+    if(axes[0] == inRank - 1 && inRank > 2) {
+        delete input;
+        delete dLdO;
+        delete dLdI;
+    }
+
     return Status::OK();
 }
 
@@ -696,7 +786,7 @@ PLATFORM_CHECK(batchnorm_bp) {
     const int inRank = input->rankOf();
 
-    return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
+    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
            (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
index 4871c12e4..654d4bf2c 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
@@ -2901,31 +2901,29 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
 
     delete result;
 }
 
-////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
 
     NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
-    NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {4}, {1.1, 1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
     NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gamma   ('c', {4}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668, -0.509112, -0.254556, 0., 0.254556, 0.509112, 0.763668, 1.018224, 1.272779,
-                                    1.527335,  1.781891,  2.036447,  2.291003,  2.545559,  2.800115, 3.054671, 3.309227, 3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32);
-    NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011,
+                                    0.000011,  0.000011,  0.000011,  0.000011,  0.000034,  0.000034,  0.000034,  0.000034,  0.000056,  0.000056,  0.000056,  0.000056}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
-    mean.assign(1.);
-    variance.assign(0.5);
+    variance.assign(0.46666667);
     gamma.assign(1.2);
-    // beta.assign(1.);     // has no effect on gradient calculations
+    beta.assign(1.);        // has no effect on gradient calculations
     gradO.linspace(-0.9, 0.15);
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -2945,20 +2943,22 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
 
     delete results;
 }
 
+
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
 
-    NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
-    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15});
-    NDArray variance('c', {3}, {0.5, 0.6, 0.7});
-    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4});
-    NDArray beta    ('c', {3}, nd4j::DataType::DOUBLE);
-    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
+    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {3}, {1.05, 1.1, 1.15}, nd4j::DataType::FLOAT32);
+    NDArray variance('c', {3}, {0.5, 0.6, 0.7}, nd4j::DataType::FLOAT32);
+    NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4}, nd4j::DataType::FLOAT32);
+    NDArray beta    ('c', {3}, nd4j::DataType::FLOAT32);
+    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668, -0.503484, -0.251742, 0., 0.251742, 0.501992, 0.752989, 1.003985, 1.254981,
-                                    1.527335,  1.781891,  2.036447,  2.291003,  2.517418,  2.76916,  3.020902, 3.272644, 3.513947, 3.764943, 4.015939, 4.266936});
-    NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388});
-    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4});
+    NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784,  0.396631,  0.343747,
+                                    0.290863,  0.237978,  0.360849,  0.441037,  0.521226,  0.601415,  0.273784,  0.334625,  0.395465,  0.456306, -0.237978,
+                                   -0.290863, -0.343746, -0.396631}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {3}, {5.81236, 7.048771, 12.155388}, nd4j::DataType::FLOAT32);
+    NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     // beta.assign(1.);     // has no effect on gradient calculations
     gradO.linspace(-0.9, 0.15);
 
@@ -2966,7 +2966,7 @@
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -2989,17 +2989,18 @@
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
 
-    NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
-    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
-    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
-    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9});
-    NDArray beta    ('c', {2,1,4}, nd4j::DataType::DOUBLE);
-    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE);
+    NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
+    NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, nd4j::DataType::FLOAT32);
+    NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, nd4j::DataType::FLOAT32);
+    NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, nd4j::DataType::FLOAT32);
+    NDArray beta    ('c', {2,1,4}, nd4j::DataType::FLOAT32);
+    NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668, -0.509112, -0.251742, 0., 0.251556, 0.509112, 0.755225, 1.003985, 1.25778,
-                                    1.517885,  1.784991,  2.05947,   2.341504,  2.529808,  2.804986, 3.089205, 3.382173, 3.541731, 3.824981, 4.11894, 4.422841});
-    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234});
-    NDArray expdLdB('c', {2,1,4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85});
+    NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000,  0.577002,
+                                    0.744041,  0.850999,  0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000,
+                                   -0.000000, -0.000000,  0.386037,  0.350205,  0.312047,  0.271736}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234}, nd4j::DataType::FLOAT32);
+    NDArray expdLdB('c', {2,1,4}, {-0.9, -0.45, 0., 0.45, 4.5, 4.95, 5.4, 5.85}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     // beta.assign(1.);     // has no effect on gradient calculations
     gradO.linspace(-0.9, 0.15);
 
@@ -3007,7 +3008,7 @@
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -3037,8 +3038,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32);
-    NDArray expdLdG('c', {4}, {1.442483, 0.9502, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, nd4j::DataType::FLOAT32);
+    NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
 
@@ -3046,7 +3047,7 @@
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -3076,8 +3077,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, -0.466136, -0.233068, 0., 0.233068, -0.442716, -0.664075, -0.885433, -1.106791, 1.287169, 1.501697, 1.716225, 1.930753,
-                                    -2.545559, -2.800115, -3.054671, -3.309227, 3.262951, 3.496019, 3.729087, 3.962155, -3.984448, -4.205806, -4.427164, -4.648522, 4.719618, 4.934146, 5.148675, 5.363203}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614,  0.561404,  0.502309,  0.443214,  0.384118, -1.168243,
+                                     -1.045270, -0.922297, -0.799324,  1.899026,  1.699128,  1.499231,  1.299333,  0.504614,  0.582247,  0.659880,  0.737512, -0.384118,
+                                     -0.443214, -0.502308, -0.561404,  0.799324,  0.922297,  1.045270,  1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {4.2, 9., 13.8, 18.6}, nd4j::DataType::FLOAT32);
 
@@ -3086,7 +3088,7 @@
     input.linspace(0.1, 0.1);
     gradO.linspace(-0.9, 0.15);
 
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -3116,8 +3118,9 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753,
-                                    -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829, -4.427164, 4.50509, -5.60023, 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,2,2,4}, {-4.989124,  2.540357, -1.515022,  0.791769, -3.563660,  1.814540, -1.082159,  0.565549, -2.138196,  1.088724, -0.649295,
+                                      0.339329, -0.712732,  0.362908, -0.216432,  0.113110,  0.712732, -0.362908,  0.216432, -0.113110,  2.138195, -1.088724,  0.649295,
+                                     -0.339330,  3.563660, -1.814540,  1.082159, -0.565549,  4.989125, -2.540356,  1.515022, -0.791770}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     gradO.linspace(-0.9, 0.15);
 
@@ -3126,7 +3129,7 @@
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -3156,20 +3159,21 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534, 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866,
-                                       1.930753, -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829, -4.427164,
-                                       4.50509, -5.60023, 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803, 6.221315, -7.636677, 7.225105, -7.083463,
-                                       7.079428, -8.6549, 8.157377, -7.968895, 7.93754, -9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192, -9.739761,
-                                       9.653765, -11.709571, 10.954192, -10.625194, 10.511877, -12.727795, 11.886464, -11.510627, 11.36999, -13.746018, 12.818735, -12.39606, 12.228102}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142,
+                                       -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662,
+                                       15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032,
+                                       -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788,
+                                       -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {282.38734, 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {57.6, 60., 62.4, 64.8}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     gradO.linspace(-0.9, 0.15);
 
+
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
 
@@ -3201,10 +3205,11 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32);
     NDArray gradO   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
 
-    NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0., -0.254556, 0.466136, 0.699204, 0.932272, 1.16534, 1.398407, 1.631475, 1.864543, 2.097611,
-                                      -2.213582, -2.43494, -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309, 3.861506, 4.076034, 4.290562, 4.50509, 4.719618, 4.934146, 5.148675, 5.363203,
-                                      -6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784,
-                                      -9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999, 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32);
+    NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301,
+                                       32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767,
+                                       -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526,
+                                       30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773,
+                                       31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, nd4j::DataType::FLOAT32);
     NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
     NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
 
     input.linspace(0.1, 0.1);
     gradO.linspace(-0.9, 0.15);
 
@@ -3213,7 +3218,7 @@
     nd4j::ops::batchnorm_bp op;
 
-    auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
+    auto results = op.execute({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1});
 
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
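The updated expectations in these tests follow from treating mean and variance as functions of the input, so they can be sanity-checked by numerically differentiating the surrogate loss L = [g * output]_sum with the statistics recomputed from the perturbed input. The standalone sketch below shows one way to do that; it is illustrative only (bnLoss and the toy data are hypothetical, biased variance is assumed) and is not part of the test suite.

#include <cmath>
#include <cstdio>
#include <vector>

// Surrogate loss L = sum_i g[i] * (gamma*(x[i] - m)*stdInv + beta),
// with m and v (biased) recomputed from x, matching the "full derivative" convention.
static float bnLoss(const std::vector<float>& x, const std::vector<float>& g,
                    float gamma, float beta, float eps) {
    const int N = static_cast<int>(x.size());
    float m = 0.0f;
    for (float xi : x) m += xi;
    m /= N;
    float v = 0.0f;
    for (float xi : x) v += (xi - m) * (xi - m);
    v /= N;                                          // biased variance
    const float stdInv = 1.0f / std::sqrt(v + eps);
    float L = 0.0f;
    for (int i = 0; i < N; ++i)
        L += g[i] * (gamma * (x[i] - m) * stdInv + beta);
    return L;
}

int main() {
    // toy single-channel batch mirroring the linspace-style test data
    std::vector<float> x = {0.1f, 0.5f, 0.9f, 1.3f};
    std::vector<float> g = {-0.9f, -0.75f, -0.6f, -0.45f};
    const float gamma = 1.2f, beta = 1.0f, eps = 1e-5f, h = 1e-3f;

    // central finite difference dL/dx[j], to be compared against the analytic dLdI
    for (int j = 0; j < (int)x.size(); ++j) {
        std::vector<float> xp = x, xm = x;
        xp[j] += h;
        xm[j] -= h;
        const float num = (bnLoss(xp, g, gamma, beta, eps) - bnLoss(xm, g, gamma, beta, eps)) / (2.0f * h);
        std::printf("dLdI[%d] ~= %f\n", j, num);
    }
    return 0;
}

Comparing such central differences against the analytic gradient is one way expected arrays like expdLdI above can be validated independently of the op implementation.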