From 029a69a83594e8cc48f04d886dfc8a59096c80d8 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Sat, 26 Oct 2019 14:14:21 +0300 Subject: [PATCH] Shyrma bn mkl bp (#14) * - write code for new batchnorm backprop Signed-off-by: Yurii * - testing batchnorm backprop Signed-off-by: Yurii * - write code for batchnorm backprop based on mkl dnn api Signed-off-by: Yurii * - testing and fixing bugs in batchnorm_bp mkl dnn Signed-off-by: Yurii * - made corrections required by reviewer Signed-off-by: Yurii * - change name in java wrapper for batchnorm op Signed-off-by: Yurii --- libnd4j/include/helpers/ConstantShapeHelper.h | 1 + .../helpers/cpu/ConstantShapeHelper.cpp | 4 + .../helpers/cuda/ConstantShapeHelper.cu | 4 + .../ops/declarable/generic/nn/batchnorm.cpp | 480 ++++------- libnd4j/include/ops/declarable/headers/nn.h | 51 +- .../ops/declarable/helpers/cpu/batchnorm.cpp | 2 + .../declarable/platform/mkldnn/batchnorm.cpp | 787 +++++++++++++++--- .../declarable/platform/mkldnn/lstmLayer.cpp | 2 - .../platform/mkldnn/mkldnnUtils.cpp | 80 +- .../declarable/platform/mkldnn/mkldnnUtils.h | 4 +- .../benchmarking/impl/FullBenchmarkSuit.cpp | 4 +- .../layers_tests/DeclarableOpsTests1.cpp | 123 --- .../layers_tests/DeclarableOpsTests10.cpp | 97 ++- .../layers_tests/DeclarableOpsTests9.cpp | 336 +++++++- .../tests_cpu/layers_tests/MklDnnTests.cpp | 2 +- .../impl/layers/convolution/BatchNorm.java | 2 +- 16 files changed, 1280 insertions(+), 699 deletions(-) diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 585db0198..d5ea9abe9 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -60,6 +60,7 @@ namespace nd4j { Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector &shape); Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); + Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo); Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace); Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index 531b68004..bcedd727e 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -99,6 +99,10 @@ namespace nd4j { return bufferForShapeInfo(descriptor).primaryAsT(); } + Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) { + return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); + } + Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); return bufferForShapeInfo(descriptor).primaryAsT(); diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index 4004b9895..aae62594c 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -102,6 +102,10 @@ namespace nd4j { return bufferForShapeInfo(descriptor).primaryAsT(); } + Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* 
shapeInfo) { + return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast(shapeInfo))); + } + Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); return bufferForShapeInfo(descriptor).primaryAsT(); diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 6ef4a49d5..5641bab43 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -29,84 +29,8 @@ namespace nd4j { namespace ops { -CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - - // FIXME: double? - const double epsilon = T_ARG(0); - - if(applyScale) - gamma = INPUT_VARIABLE(3); - if(applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - - std::vector inArrs(block.width()); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); - - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); - - // normalized output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - - auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); - if(applyScale) - sigmaInvGam *= *gamma; - - NDArray inputMinusMean; - if(!input->isSameShape(output) && !mean->isSameShape(output)) { - auto inputTiled = NDArray(output, false, block.launchContext()); - input->tile(inputTiled); - inputMinusMean = inputTiled - *mean; - } - else - inputMinusMean = *input - *mean; - - if (applyOffset) - output->assign(inputMinusMean * sigmaInvGam + *beta); - else - output->assign(inputMinusMean * sigmaInvGam); - - return Status::OK(); -} - - DECLARE_TYPES(batchnorm) { - getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - ////////////////////////////////////////////////////////////////////////// -DECLARE_SHAPE_FN(batchnorm) { - - std::vector inArrs(block.width()); - auto in = inputShape->at(0); - for(int i = 0; i < block.width(); ++i) - inArrs[i] = INPUT_VARIABLE(i); - - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); - - auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, DataTypeUtils::pickFloatingType(ArrayOptions::dataType(in)))); - return SHAPELIST(result); -} - -////////////////////////////////////////////////////////////////////////// -CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { +CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { auto input = INPUT_VARIABLE(0); auto mean = INPUT_VARIABLE(1); @@ -123,7 +47,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { if(applyScale) gamma = INPUT_VARIABLE(3); 
if(applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + beta = INPUT_VARIABLE(3 + (int)applyScale); const int numOfIntArgs = block.getIArguments()->size(); const int inRank = input->rankOf(); @@ -137,30 +61,31 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { axes.push_back(inRank-1); // default dimension to reduce along is last dimension const int numOfAxes = axes.size(); - REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_NEW op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); - - // get, for example, something like {1, inDim1, 1, inDim3, 1} if axes = {1, 3} - std::vector expShapeWithUnities(inRank, 1); - for(int i = 0; i < numOfAxes; ++i) - expShapeWithUnities[axes[i]] = input->sizeAt(axes[i]); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} - std::vector expShape = numOfAxes == 1 ? std::vector(1, input->sizeAt(axes[0])) : expShapeWithUnities; - std::string expShapeStr = ShapeUtils::shapeAsString(expShape); + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } - REQUIRE_TRUE(ShapeUtils::shapeAsString(mean) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of mean array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(mean).c_str()); - REQUIRE_TRUE(ShapeUtils::shapeAsString(variance) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of variance array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(variance).c_str()); + REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); if(gamma) - REQUIRE_TRUE(ShapeUtils::shapeAsString(gamma) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of gamma array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); if(beta) - REQUIRE_TRUE(ShapeUtils::shapeAsString(beta) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of beta array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(beta).c_str()); + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), 
ShapeUtils::shapeAsString(beta).c_str()); // types of all input arrays should be the same for(int i = 1; i < block.width(); ++i) - REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !"); + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !"); - nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0); + nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); @@ -168,15 +93,15 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { return Status::OK(); } -DECLARE_TYPES(batchnorm_new) { +DECLARE_TYPES(batchnorm) { getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); } -DECLARE_SHAPE_FN(batchnorm_new) { +DECLARE_SHAPE_FN(batchnorm) { auto inShapeInfo = inputShape->at(0); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - + auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape return SHAPELIST(CONSTANT(outShapeInfo)); @@ -184,290 +109,177 @@ DECLARE_SHAPE_FN(batchnorm_new) { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - NDArray *dLdO = nullptr; // next epsilon - auto dLdI = OUTPUT_VARIABLE(0); - auto dLdM = OUTPUT_VARIABLE(1); - auto dLdV = OUTPUT_VARIABLE(2); - NDArray *dLdG = nullptr; - NDArray *dLdB = nullptr; + NDArray* input = INPUT_VARIABLE(0); + NDArray* mean = INPUT_VARIABLE(1); + NDArray* variance = INPUT_VARIABLE(2); + NDArray* dLdO = INPUT_VARIABLE(3); // next epsilon + NDArray* gamma = nullptr; + NDArray* beta = nullptr; - const bool applyScale = (bool)INT_ARG(0); - const bool applyOffset = (bool)INT_ARG(1); - // FIXME: double? 
- const double epsilon = T_ARG(0); + NDArray* dLdI = OUTPUT_VARIABLE(0); + NDArray* dLdM = OUTPUT_VARIABLE(1); + NDArray* dLdV = OUTPUT_VARIABLE(2); + NDArray* dLdG = nullptr; + NDArray* dLdB = nullptr; - const int dLdONum = static_cast(applyScale) + static_cast(applyOffset); + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); if(applyScale) { - gamma = INPUT_VARIABLE(3); + gamma = INPUT_VARIABLE(4); dLdG = OUTPUT_VARIABLE(3); } if(applyOffset) { - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - dLdB = OUTPUT_VARIABLE(3 + static_cast(applyScale)); + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); } - - dLdO = INPUT_VARIABLE(3 + dLdONum); - - std::vector inArrs(block.width()); - for(int i = 0; i < 4 + dLdONum; ++i) - inArrs[i] = INPUT_VARIABLE(i); - // check whether all input shapes are mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); + + // evaluate expected shape for mean, variance and gamma. 
These 3 arrays should have identical shapes + // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5} + std::vector expShape; + if(numOfAxes == 1) + expShape.push_back(input->sizeAt(axes[0])); + else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3} + expShape = std::vector(inRank, 1); + for(uint i = 0; i < numOfAxes; ++i) + expShape[axes[i]] = input->sizeAt(axes[i]); + } + + REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma) + REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta) + REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); + + REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + if(i != 3) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); // ***** calculations ***** // - auto sigmaInv = (*variance + epsilon).transform(transform::RSqrt); - - NDArray sigmaInvGamdLdO = -sigmaInv * *dLdO; - if(applyScale) - sigmaInvGamdLdO *= *gamma; + // formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta - NDArray inputMinusMean; - if(!input->isSameShape(dLdO) && !mean->isSameShape(dLdO)) { - auto inputTiled = NDArray(dLdO, false, block.launchContext()); - input->tile(inputTiled); - inputMinusMean = inputTiled - *mean; - } - else - inputMinusMean = *input - *mean; + // consider mean and variance as constants (since we get them as inputs and don't calculate them) + // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5 + // dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5 + // dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5 + // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5 + // dLdB = dLdO_sum + + const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); + + NDArray temp1 = *variance + epsilon; + temp1.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) + auto temp2 = temp1.transform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 + if(applyScale) + temp2 *= *gamma; // gamma / (variance + epsilon)^0.5 + + NDArray temp3(input); // empty array with same shape as input + input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3); // input - mean + temp3 *= *dLdO; // (input - mean) * dLdO + + const bool keepUnitiesInShape = inRank == mean->rankOf(); // dLdI - if(!dLdI->isSameShape(dLdO)) - dLdI->assign( 
(-sigmaInvGamdLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdI->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdI->assign(-sigmaInvGamdLdO); + dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI); // dLdM - if(!dLdM->isSameShape(dLdO)) - dLdM->assign( sigmaInvGamdLdO.reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdM->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdM->assign(sigmaInvGamdLdO); + dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); // dLdO sum over excluded axes - // dLdV - if(!dLdV->isSameShape(dLdO)) { - dLdV->assign( (sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdV->getShapeInfo(), dLdO->getShapeInfo())) ); - } - else - dLdV->assign(sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f); + // dLdB + if(applyOffset) + dLdB->assign(dLdM); + + // dLdM + // dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2); + // dLdM->applyTransform(nd4j::transform::Neg); + *dLdM = 0; // put zeros so far + + //dLdV + temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); // ((input - mean) * dLdO)_sum // dLdG if(applyScale) { - if(!dLdG->isSameShape(dLdO)) - dLdG->assign( (sigmaInv * inputMinusMean * *dLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdG->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdG->assign(sigmaInv * inputMinusMean * *dLdO); + dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG); + // dLdV->assign(dLdG); + dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma); } + else + // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2); - // dLdB - if(applyOffset) { - if(!dLdB->isSameShape(dLdO)) - dLdB->assign(dLdO->reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdB->getShapeInfo(), dLdO->getShapeInfo())) ); - else - dLdB->assign(dLdO); - } + // dLdV + // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1); + // *dLdV *= -0.5; + *dLdV = 0; // put zeros so far return Status::OK(); } - DECLARE_TYPES(batchnorm_bp) { - getOpDescriptor() - ->setAllowedInputTypes(0, nd4j::DataType::ANY) - ->setAllowedInputTypes(1, nd4j::DataType::ANY) - ->setAllowedInputTypes(2, nd4j::DataType::ANY) - ->setAllowedInputTypes(3, nd4j::DataType::ANY) - ->setAllowedInputTypes(4, nd4j::DataType::ANY) - ->setAllowedInputTypes(5, {ALL_FLOATS}) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(batchnorm_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, nd4j::DataType::ANY) + ->setAllowedInputTypes(1, nd4j::DataType::ANY) + ->setAllowedInputTypes(2, nd4j::DataType::ANY) + ->setAllowedInputTypes(3, {ALL_FLOATS}) + ->setAllowedInputTypes(4, nd4j::DataType::ANY) + ->setAllowedInputTypes(5, nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} ////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(batchnorm_bp) { + Nd4jLong* inShapeInfo = inputShape->at(0); + Nd4jLong* meanShapeInfo = inputShape->at(1); + const bool applyScale = (bool)INT_ARG(0); const bool applyOffset = (bool)INT_ARG(1); - const int dLdONum = static_cast(applyScale) + static_cast(applyOffset); + DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); - std::vector inArrs(block.width()); - for(int i = 0; i < 4 + dLdONum; ++i) - inArrs[i] = INPUT_VARIABLE(i); + auto shapes = SHAPELIST(); - // check whether all input shapes are 
mutually broadcastable - Nd4jLong* outShapeInfo = nullptr; - const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); - REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); + // dLdI shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo)); - Nd4jLong* dLdIShapeInfo(nullptr), *dLdMShapeInfo(nullptr), *dLdVShapeInfo(nullptr), *dLdGShapeInfo(nullptr), *dLdBShapeInfo(nullptr); - COPY_SHAPE(inputShape->at(0), dLdIShapeInfo); - COPY_SHAPE(inputShape->at(1), dLdMShapeInfo); - COPY_SHAPE(inputShape->at(2), dLdVShapeInfo); + // dLdM shapeInfo + shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo)); - if(applyScale) { - COPY_SHAPE(inputShape->at(3), dLdGShapeInfo); - } - if(applyOffset){ - COPY_SHAPE(inputShape->at(3 + static_cast(applyScale)), dLdBShapeInfo); - } + // dLdV shapeInfo (same as dLdM) + shapes->push_back(shapes->at(shapes->size()-1)); - if(!applyScale && !applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo)); + // dLdG shapeInfo (same as dLdM) + if(applyScale) + shapes->push_back(shapes->at(shapes->size()-1)); - if(applyScale && !applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo)); + // dLdB shapeInfo (same as dLdM) + if(applyOffset) + shapes->push_back(shapes->at(shapes->size()-1)); - if(!applyScale && applyOffset) - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdBShapeInfo)); - - return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo), CONSTANT(dLdBShapeInfo)); + return shapes; } - // ////////////////////////////////////////////////////////////////////////// - // CONFIGURABLE_OP_IMPL(batchnorm_bp, 5, 1, true, 0, 1) { - - // NDArray* input = INPUT_VARIABLE(0); - // NDArray* epsilon = INPUT_VARIABLE(1); - // NDArray* gamma = INPUT_VARIABLE(2); - // NDArray* dGlobalMeanView = INPUT_VARIABLE(3); - // NDArray* dGlobalVarView = INPUT_VARIABLE(4); - // NDArray* outEpsilon = this->getZ(block); - // std::vector argI = *(block.getIArguments()); - // const int bS = epsilon->sizeAt(0); - // bool isLockGammaBeta = (bool)argI[0]; - // const int* epsilonShape = epsilon->getShapeInfo() + 1; - // const T eps = (T)1e-5; - - // int rank = epsilon->rankOf(); - // std::initializer_list dimensions; - // int effectiveBatchSize; - // if (rank == 2) { - // dimensions = {0}; - // effectiveBatchSize = bS; - // } - // else if (rank == 4) { - // dimensions = {0, 2, 3}; - // effectiveBatchSize = input->sizeAt(0)*input->sizeAt(2)*input->sizeAt(3); - // } - // else - // throw "Graph operation batchnorm_bp: the epsilon rank must be equal to 2 or 4 !"; - - // NDArray *mean(nullptr), *var(nullptr), *dBeta(nullptr), *dGamma(nullptr), *dLdVar(nullptr), *dxmu1(nullptr), *dxmu2(nullptr); - // mean = input->template reduceAlongDimension>(dimensions); - // var = input->template varianceAlongDimension>(false, dimensions); - // var->template applyScalar>(eps, nullptr); - // auto std = new NDArray(var->getShapeInfo(), block.getWorkspace()); - // var->template applyTransform>(std, nullptr); - - // auto xMu = new NDArray(input->getShapeInfo(), block.getWorkspace()); - // auto xHat = new NDArray(input->getShapeInfo(), block.getWorkspace()); - // auto temp1 = new 
NDArray(epsilon->getShapeInfo(), block.getWorkspace()); - // auto temp2 = new NDArray(std->getShapeInfo(), block.getWorkspace()); - // auto dGammaView = new NDArray('c', {1, epsilonShape[1]}, block.getWorkspace()); - // auto dBetaView = new NDArray('c', {1, epsilonShape[1]}, block.getWorkspace()); - // auto dxhat = new NDArray(epsilon->getShapeInfo(), block.getWorkspace()); - - // if (rank == 2) { - // input->subRowVector(mean, xMu); - // xMu->divRowVector(std, xHat); - // } - // else { - // input->template applyBroadcast>({1}, mean, xMu, nullptr); - // xMu->template applyBroadcast>({1}, std, xHat, nullptr); - // } - - // dBeta = epsilon->sum(dimensions); // dL/dBeta = sum_examples dL/dOut - // epsilon->template applyPairwiseTransform>(xHat, temp1, nullptr); //dL/dGamma = sum_examples dL/dOut .* xHat - // dGamma = temp1->sum(dimensions); //dL/dGamma = sum_examples dL/dOut .* xHat - - // if (isLockGammaBeta) - // epsilon->template applyPairwiseTransform>(gamma, dxhat, nullptr); - // else {// Standard case - // if(rank == 2) - // epsilon->mulRowVector(gamma, dxhat); //dL/dxHat = dL/dOut . gamma Shape: [minibatchSize, nOut] - // else - // epsilon->template applyBroadcast>({1}, gamma, dxhat, nullptr); - // } - - // // dLdVar - dL/dVariance, shape: [1, miniBatch] - // dxhat->template applyPairwiseTransform>(xMu, temp1, nullptr); - // dLdVar = temp1->sum(dimensions); - // dLdVar->template applyScalar>((T)-0.5, nullptr); - // T powParams[] = {(T)(-3.)}; - // std->template applyTransform>(temp2, powParams); - // dLdVar->template applyPairwiseTransform>(temp2, nullptr); - - // //dL/dmu - // dxmu1 = dxhat->sum(dimensions); - // dxmu1->template applyPairwiseTransform>(std, nullptr); - // dxmu1->template applyTransform>(); - // dxmu2 = xMu->sum(dimensions); - // dxmu2->template applyScalar>((T)(-2.)/effectiveBatchSize); - // dxmu2->template applyPairwiseTransform>(dLdVar, nullptr); - - // dxmu1->template applyPairwiseTransform>(dxmu2, nullptr); - // NDArray* dLdmu = dxmu1; // = dL/dmu Shape: [1, nOut] - - // //Note the array reuse here: dxhat, xMu, dLdVar, dLdmu - all are invalid after this line (but aren't used later anyway) - // NDArray* dLdx = dxhat; - // dLdVar->template applyScalar>((T)(2.)/effectiveBatchSize); - // dLdmu->template applyScalar>((T)(1.)/effectiveBatchSize); - // if(rank == 2) { - // dLdx->divRowVector(std, dLdx); - // xMu->mulRowVector(dLdVar, xMu); - // } - // else { - // dLdx->template applyBroadcast>({1}, std, dLdx, nullptr); - // xMu->template applyBroadcast>({1}, dLdVar, xMu, nullptr); - // } - // dLdx->template applyPairwiseTransform>(xMu, nullptr); - // if(rank == 2) - // dLdx->addRowVector(dLdmu, dLdx); - // else - // dLdx->template applyBroadcast>({1}, dLdmu, dLdx, nullptr); - - // *outEpsilon = *dLdx; - - // //TODO rework this to avoid the assign here - // // dGammaView->assign(dGamma); - // // dBetaView->assign(dBeta); - // // dGlobalMeanView->assign((T)0.); - // // dGlobalVarView->assign((T)0.); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView); - // // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView); - - // delete std; - // delete xMu; - // delete xHat; - // delete mean; - // delete var; - // delete dBeta; - // delete dGamma; - // delete dLdVar; - // delete dxmu1; - // delete dxmu2; - // delete temp1; - 
// delete temp2; - // delete dxhat; - // delete dGammaView; - // delete dBetaView; - - // return ND4J_STATUS_OK; - // } - - - } diff --git a/libnd4j/include/ops/declarable/headers/nn.h b/libnd4j/include/ops/declarable/headers/nn.h index 313707869..9f9b0e40a 100644 --- a/libnd4j/include/ops/declarable/headers/nn.h +++ b/libnd4j/include/ops/declarable/headers/nn.h @@ -29,12 +29,12 @@ namespace nd4j { #if NOT_EXCLUDED(OP_softmax) DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); - #endif + #endif /** * Local response normalization implementation as TF. * input: 4D array - * + * * T args: * * 0: bias @@ -42,8 +42,8 @@ namespace nd4j { * 2: beta * * Int arg: depth - optional local radius - * - * output - 4D array + * + * output - 4D array */ #if NOT_EXCLUDED(OP_lrn) DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); @@ -51,10 +51,10 @@ namespace nd4j { /** * Local response normalization - backprop variant. - * input: + * input: * 0 - 4D array of data * 1 - epsilon - 4D array of approximation - * + * * T args: * * 0: bias @@ -70,34 +70,31 @@ namespace nd4j { #endif /** - * Batch normalization implementation. + * Batch normalization implementation. * Reference: https://arxiv.org/abs/1502.03167v3 - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: * variance: * gamma: * beta: - * + * * Int args: * 0: apply scale * 1: apply offset - * - * + * + * * T args: * 0: epsilon */ #if NOT_EXCLUDED(OP_batchnorm) DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); #endif - #if NOT_EXCLUDED(OP_batchnorm_new) - DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2); - #endif /** * back prop in batch normalization - * + * * Expected arguments: * input: input array (any number of dimensions) * mean: @@ -105,11 +102,11 @@ namespace nd4j { * gamma: optional * beta: optional * dLdOut: next epsilon - * + * * Int args: * 0: apply scale - * 1: apply offset - * + * 1: apply offset + * * T args: * 0: epsilon * @@ -117,8 +114,8 @@ namespace nd4j { * dL/dInput * dL/dMean * dL/dVariance - * dL/dGamma - * dL/dBeta + * dL/dGamma, optional + * dL/dBeta, optional */ #if NOT_EXCLUDED(OP_batchnorm) DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); @@ -131,30 +128,30 @@ namespace nd4j { * x: parameters, any shape * y: gradients. same shape as x * lr: optional, learning rate - * + * * T args: * 0: optional, learning rate */ #if NOT_EXCLUDED(OP_apply_sgd) - DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); #endif /** * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167. 
* Expected arguments: * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where - * bS - batch size - * iH - input height - * iW - input width + * bS - batch size + * iH - input height + * iW - input width * iD - input depth (or number of channels) * scale: 1D input array of scale factors, shape [iD] * offset: 1D input array of offsets (shifts), shape [iD] * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false - * + * * T input arguments: * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x - * + * * integer input arguments: * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW * 1: isTraining, may have two values: zero -> inference, unity -> training diff --git a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp index d6c4da4a1..a0847f704 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/batchnorm.cpp @@ -32,6 +32,8 @@ namespace helpers { template static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector& axes, const double epsilon) { + // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta + NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo T eps = epsilon; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 4947a39c0..1a2780d52 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -28,139 +29,679 @@ #include #include -using namespace mkldnn; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(batchnorm_new) { - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; +namespace nd4j { +namespace ops { +namespace platforms { - auto output = OUTPUT_VARIABLE(0); - const bool applyScale = (bool) INT_ARG(0); - const bool applyOffset = (bool) INT_ARG(1); - const double epsilon = T_ARG(0); +////////////////////////////////////////////////////////////////////////// +static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) { - if (applyScale) - gamma = INPUT_VARIABLE(3); - if (applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // also it gives wrong results for formats nhwc and ndhwc - std::vector axes; - if (block.numI() > 2) - for (int i = 2; i < block.numI(); ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf() - 1); + // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // mean -> 1D [c] + // variance -> 1D [c] + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta + // z(output) - same shape as x - std::vector 
shape({2, mean->lengthOf()}); - NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); - weights({0, 1, 0, 0}).assign(1.0f); - weights({1, 2, 0, 0}).assign(0.0f); + const int xRank = x->rankOf(); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md( - empty), user_dst_md(empty); + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto norm_flag = normalization_flags::use_global_stats; - if (applyScale || applyOffset) - norm_flag |= normalization_flags::use_scale_shift; + // input type + mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; - mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, - &batchnorm_src_md, nullptr, &batchnorm_dst_md, - &user_src_md, nullptr, &user_dst_md, axes[0]); + // indicate whether gamma or/and beta are given + auto flags = mkldnn::normalization_flags::use_global_stats; + if (weights != nullptr) + flags |= mkldnn::normalization_flags::use_scale_shift; - auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag); + mkldnn::memory::dims dims; + mkldnn::memory::format_tag format; - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn::stream stream(engine); - auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine); - auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); - auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, - mean->buffer()); - auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, - variance->buffer()); - auto batchnorm_src_memory = user_src_memory; - mkldnn::memory m(batchnorm_src_md, engine); - if (m.get_desc() != user_src_memory.get_desc()) { - batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); - reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, - batchnorm_src_memory); - } - auto batchnorm_dst_memory = user_dst_memory; - if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); - } - if (applyScale || applyOffset) { - if (gamma != nullptr) { - weights({0, 1, 0, 0}).assign(gamma); - } - if (beta != nullptr) { - weights({1, 2, 0, 0}).assign(beta); - } - - auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); - batch_normalization_forward(batchnorm_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, batchnorm_src_memory}, - {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, - {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, - {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, - {MKLDNN_ARG_DST, batchnorm_dst_memory}}); - } else { - batch_normalization_forward(batchnorm_prim_desc).execute(stream, - {{MKLDNN_ARG_SRC, batchnorm_src_memory}, - {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, - {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, - {MKLDNN_ARG_DST, batchnorm_dst_memory}}); - } - if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, - user_dst_memory); - } - stream.wait(); - - return Status::OK(); - } - - PLATFORM_CHECK(batchnorm_new) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 
2) - return false; - - auto input = INPUT_VARIABLE(0); - auto mean = INPUT_VARIABLE(1); - auto variance = INPUT_VARIABLE(2); - NDArray *gamma = nullptr; - NDArray *beta = nullptr; - - auto output = OUTPUT_VARIABLE(0); - - const bool applyScale = (bool) INT_ARG(0); - const bool applyOffset = (bool) INT_ARG(1); - const double epsilon = T_ARG(0); - - if (applyScale) - gamma = INPUT_VARIABLE(3); - if (applyOffset) - beta = INPUT_VARIABLE(3 + static_cast(applyScale)); - - std::vector axes; - if (block.numI() > 2) - for (int i = 2; i < block.numI(); ++i) - axes.push_back(INT_ARG(i)); - else - axes.push_back(input->rankOf() - 1); - - return block.isUseMKLDNN() && - nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && - axes.size() == 1; - } - } + if(xRank == 2) { + dims = {x->sizeAt(0), x->sizeAt(1)}; + format = mkldnn::memory::format_tag::nc; } + else if(xRank == 4) { + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; + format = mkldnn::memory::format_tag::nchw; + } + else { // xRank = 5 + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; + format = mkldnn::memory::format_tag::ncdhw; + } + + // memory descriptors for arrays + + // x + mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); + x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; + + // z, output + mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format); + z_user_md.data.format_kind = mkldnn_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; + z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; + if(xRank > 2) { + z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2]; + z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3]; + } + if(xRank > 4) + z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4]; + + + // batchnorm forward description + mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + mkldnn::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[MKLDNN_ARG_SRC] = x_mkl_mem; + + // z + auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer()); + const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? 
mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; + if (zReorder) + mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); + args[MKLDNN_ARG_DST] = z_mkl_mem; + + // mean + auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if(weights != nullptr) { + + auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + } + + // run calculations + mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); +} + + +////////////////////////////////////////////////////////////////////////// +static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, + const float epsilon, NDArray* dLdI, NDArray* dLdW) { + + // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any) + // also it gives wrong results for formats nhwc and ndhwc + + // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // mean -> 1D [c] + // variance -> 1D [c] + // dLdO - same shape as x + // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta + // dLdI - same shape as x + // dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta + + const int xRank = x->rankOf(); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // input type + mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; + + // indicate whether gamma or/and beta are given + auto flags = mkldnn::normalization_flags::use_global_stats; + if (weights != nullptr) + flags |= mkldnn::normalization_flags::use_scale_shift; + + mkldnn::memory::dims dims; + mkldnn::memory::format_tag format; + + if(xRank == 2) { + dims = {x->sizeAt(0), x->sizeAt(1)}; + format = mkldnn::memory::format_tag::nc; + } + else if(xRank == 4) { + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; + format = mkldnn::memory::format_tag::nchw; + } + else { // xRank = 5 + dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; + format = mkldnn::memory::format_tag::ncdhw; + } + + // memory descriptors for arrays + + // x + mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); + x_user_md.data.format_kind = mkldnn_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; + x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; + + // dLdO + mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format); + dLdO_user_md.data.format_kind = 
mkldnn_blocked; // overrides format + dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; + dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; + if(xRank > 2) { + dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2]; + dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3]; + } + if(xRank > 4) + dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; + + // dLdI + mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format); + mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format); + dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format + dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; + dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; + if(xRank > 2) { + dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2]; + dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3]; + } + if(xRank > 4) + dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; + + // batchnorm forward description + mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // batchnorm backprop description + mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); + mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + mkldnn::stream stream(engine); + + // provide memory and check whether reorder is required + + // x + auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); + const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[MKLDNN_ARG_SRC] = x_mkl_mem; + + // dLdO + auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer()); + const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc(); + auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem; + if (dLdOReorder) + mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); + args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem; + + // mean + auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); + args[MKLDNN_ARG_MEAN] = mean_mkl_mem; + + // variance + auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); + args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; + + // dLdI + auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer()); + const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc(); + auto dLdI_mkl_mem = dLdIReorder ? 
mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem; + args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem; + + // gamma and beta (and their gradients) if they are present + if(weights != nullptr) { + + auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); + args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; + + auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); + args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; + } + + // run calculations + mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (dLdIReorder) + mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); + + stream.wait(); + + // shape::printArray(dLdI_mkl_mem.map_data(),8); +} + +PLATFORM_IMPL(batchnorm) { + + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const double epsilon = T_ARG(0); + + if(applyScale) + gamma = INPUT_VARIABLE(3); + if(applyOffset) + beta = INPUT_VARIABLE(3 + (int)applyScale); + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma != nullptr) + REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta != nullptr) + REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); + + + NDArray *weights = nullptr; + + if(applyScale || applyOffset) { + + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, 
input->dataType()); + + if(applyScale) + (*weights)({0,1, 0,0}).assign(gamma); + else + (*weights)({0,1, 0,0}).assign(1); + if(applyOffset) + (*weights)({1,2, 0,0}).assign(beta); + else + (*weights)({1,2, 0,0}).assign(0); + } + + batchnormMKLDNN(input, mean, variance, weights, epsilon, output); + + delete weights; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(batchnorm) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) + // return false; + + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + auto mean = INPUT_VARIABLE(1); // [c] + auto variance = INPUT_VARIABLE(2); // [c] + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + auto output = OUTPUT_VARIABLE(0); // same shape as input + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if(applyScale) + gamma = INPUT_VARIABLE(3); + if(applyOffset) + beta = INPUT_VARIABLE(3 + (int)applyScale); + + + const int numOfIntArgs = block.getIArguments()->size(); + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + DataType outType = output->dataType(); + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && + gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); +} + +////////////////////////////////////////////////////////////////////////// +// PLATFORM_IMPL(batchnorm) { + +// auto input = INPUT_VARIABLE(0); +// auto mean = INPUT_VARIABLE(1); +// auto variance = INPUT_VARIABLE(2); +// NDArray *gamma = nullptr; +// NDArray *beta = nullptr; + +// auto output = OUTPUT_VARIABLE(0); + +// const bool applyScale = (bool) INT_ARG(0); +// const bool applyOffset = (bool) INT_ARG(1); +// const double epsilon = T_ARG(0); + +// if (applyScale) +// gamma = INPUT_VARIABLE(3); +// if (applyOffset) +// beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + +// std::vector axes; +// if (block.numI() > 2) +// for (int i = 2; i < block.numI(); ++i) +// axes.push_back(INT_ARG(i)); +// else +// axes.push_back(input->rankOf() - 1); + +// std::vector shape({2, mean->lengthOf()}); +// NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); +// weights({0, 1, 0, 0}).assign(1.0f); +// weights({1, 2, 0, 0}).assign(0.0f); + +// mkldnn_memory_desc_t empty; +// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty); + +// auto flag = mkldnn::normalization_flags::use_global_stats; +// if (applyScale || applyOffset) +// flag |= mkldnn::normalization_flags::use_scale_shift; + +// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, +// &batchnorm_src_md, nullptr, &batchnorm_dst_md, +// &user_src_md, nullptr, &user_dst_md, axes[0]); + +// auto batchnorm_desc = 
mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag); + +// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); +// mkldnn::stream stream(engine); +// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine); +// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); +// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); +// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, +// mean->buffer()); +// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, +// variance->buffer()); +// auto batchnorm_src_memory = user_src_memory; +// mkldnn::memory m(batchnorm_src_md, engine); +// if (m.get_desc() != user_src_memory.get_desc()) { +// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); +// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, +// batchnorm_src_memory); +// } +// auto batchnorm_dst_memory = user_dst_memory; +// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { +// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); +// } +// if (applyScale || applyOffset) { +// if (gamma != nullptr) { +// weights({0, 1, 0, 0}).assign(gamma); +// } +// if (beta != nullptr) { +// weights({1, 2, 0, 0}).assign(beta); +// } + +// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); +// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, +// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory}, +// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// } else { +// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream, +// {{MKLDNN_ARG_SRC, batchnorm_src_memory}, +// {MKLDNN_ARG_MEAN, batchnorm_mean_memory}, +// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, +// {MKLDNN_ARG_DST, batchnorm_dst_memory}}); +// } +// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { +// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, +// user_dst_memory); +// } +// stream.wait(); + +// return Status::OK(); +// } + +////////////////////////////////////////////////////////////////////////// +// PLATFORM_CHECK(batchnorm) { +// // we don't want to use mkldnn if cpu doesn't support avx/avx2 +// if (::optimalLevel() < 2) +// return false; + +// auto input = INPUT_VARIABLE(0); +// auto mean = INPUT_VARIABLE(1); +// auto variance = INPUT_VARIABLE(2); +// NDArray *gamma = nullptr; +// NDArray *beta = nullptr; + +// auto output = OUTPUT_VARIABLE(0); + +// const bool applyScale = (bool) INT_ARG(0); +// const bool applyOffset = (bool) INT_ARG(1); +// const double epsilon = T_ARG(0); + +// if (applyScale) +// gamma = INPUT_VARIABLE(3); +// if (applyOffset) +// beta = INPUT_VARIABLE(3 + static_cast(applyScale)); + +// std::vector axes; +// if (block.numI() > 2) +// for (int i = 2; i < block.numI(); ++i) +// axes.push_back(INT_ARG(i)); +// else +// axes.push_back(input->rankOf() - 1); + +// return block.isUseMKLDNN() && +// nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && +// axes.size() == 1; +// } + + +////////////////////////////////////////////////////////////////////////// 
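Editorial sketch, not part of this patch: the comments in the generic batchnorm_bp op above state the gradient formulas used when mean and variance are treated as constants (dLdI, dLdV, dLdM, dLdG, dLdB). The following standalone snippet (all names hypothetical, single channel, toy data) just restates those formulas in scalar form, which can be handy for eyeballing the MKL-DNN backprop results; it is not the library's implementation.

// editorial reference sketch for the batchnorm_bp gradient formulas,
// assuming mean and variance are constants as stated in the op comments
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // one channel, four samples (hypothetical toy values)
    const std::vector<float> x    = {1.f, 2.f, 3.f, 4.f};
    const std::vector<float> dLdO = {0.1f, -0.2f, 0.3f, -0.4f};
    const float mean = 2.5f, variance = 1.25f, gamma = 1.5f, eps = 1e-5f;

    const float stdInv = 1.f / std::sqrt(variance + eps);   // 1 / (variance + eps)^0.5

    float dLdOSum = 0.f, dLdOxmuSum = 0.f;
    for (size_t i = 0; i < x.size(); ++i) {
        dLdOSum    += dLdO[i];                               // dLdO_sum
        dLdOxmuSum += dLdO[i] * (x[i] - mean);               // (dLdO * (x - mean))_sum
    }

    // dLdI = (dLdO * gamma) / (variance + eps)^0.5, per sample
    for (size_t i = 0; i < x.size(); ++i)
        std::printf("dLdI[%zu] = %f\n", i, dLdO[i] * gamma * stdInv);

    const float dLdG = dLdOxmuSum * stdInv;                  // (dLdO * (x - mean))_sum / (variance + eps)^0.5
    const float dLdB = dLdOSum;                              // dLdO_sum
    const float dLdM = -dLdOSum * gamma * stdInv;            // -(dLdO_sum * gamma) / (variance + eps)^0.5
    const float dLdV = -0.5f * gamma * dLdOxmuSum * stdInv * stdInv * stdInv; // / (variance + eps)^1.5

    std::printf("dLdG=%f dLdB=%f dLdM=%f dLdV=%f\n", dLdG, dLdB, dLdM, dLdV);
    return 0;
}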
+PLATFORM_IMPL(batchnorm_bp) { + + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + const float epsilon = T_ARG(0); + + if(applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + const int inRank = input->rankOf(); + + // get axes args to normalize input array over + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(inRank-1); // default dimension to reduce along is last dimension + + const int numOfAxes = axes.size(); + REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); + REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); + REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); + REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); + REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); + if(gamma != nullptr) + REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); + if(beta != nullptr) + REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); + + // types of all input arrays should be the same (except dLdO) + for(int i = 1; i < block.width() - 1; ++i) + REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !"); + + + NDArray *weights = nullptr, *dLdW = nullptr; + + if(applyScale || applyOffset) { + weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); + dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); + if(applyScale) + (*weights)({0,1, 0,0}).assign(gamma); + else + (*weights)({0,1, 0,0}).assign(1); + if(applyOffset) + (*weights)({1,2, 0,0}).assign(beta); + 
else + (*weights)({1,2, 0,0}).assign(0); + } + + *dLdM = 0; + *dLdV = 0; + + batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW); + + if(applyScale || applyOffset) { + if(applyScale) + dLdG->assign((*dLdW)({0,1, 0,0})); + if(applyOffset) + dLdB->assign((*dLdW)({1,2, 0,0})); + + delete weights; + delete dLdW; + } + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(batchnorm_bp) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) + // return false; + + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* mean = INPUT_VARIABLE(1); // [c] + NDArray* variance = INPUT_VARIABLE(2); // [c] + NDArray* dLdO = INPUT_VARIABLE(3); // same as input + NDArray* gamma = nullptr; // [c] + NDArray* beta = nullptr; // [c] + + NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input + NDArray* dLdM = OUTPUT_VARIABLE(1); // [c] + NDArray* dLdV = OUTPUT_VARIABLE(2); // [c] + NDArray* dLdG = nullptr; // [c] + NDArray* dLdB = nullptr; // [c] + + const bool applyScale = (bool)INT_ARG(0); + const bool applyOffset = (bool)INT_ARG(1); + + if(applyScale) { + gamma = INPUT_VARIABLE(4); + dLdG = OUTPUT_VARIABLE(3); + } + if(applyOffset) { + beta = INPUT_VARIABLE(4 + (int)applyScale); + dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); + } + + const int numOfIntArgs = block.getIArguments()->size(); + std::vector axes; + if(numOfIntArgs > 2) + for(int i = 2; i < numOfIntArgs; ++i) + axes.push_back(INT_ARG(i)); + else + axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension + + DataType inputType = input->dataType(); + DataType meanType = mean->dataType(); + DataType varType = variance->dataType(); + DataType dLdOType = dLdO->dataType(); + DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; + DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32; + + DataType dLdIType = dLdI->dataType(); + DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; + DataType dLdBType = beta != nullptr ? 
dLdB->dataType() : DataType::FLOAT32; + + const int inRank = input->rankOf(); + + return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && + (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && + dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && + dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index e22487f43..10b392465 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -132,8 +132,6 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - mkldnn_memory_desc_t empty; - mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 4fac4a1b7..b84506c3b 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -305,50 +305,50 @@ namespace nd4j { }; - void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { - const Nd4jLong* shape = src->getShapeInfo(); - Nd4jLong rank = shape[0]; - Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - Nd4jLong dim2 = axis >= 2 ? 1 : 2; - Nd4jLong dim3 = axis >= 3 ? 2 : 3; - mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + // mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, + // mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { + // const Nd4jLong* shape = src->getShapeInfo(); + // Nd4jLong rank = shape[0]; + // Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + // Nd4jLong dim2 = axis >= 2 ? 1 : 2; + // Nd4jLong dim3 = axis >= 3 ? 2 : 3; + // mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? 
(int)shape[dim3 + 1] : 1}; - auto type = mkldnn::memory::data_type::f32; - auto format = mkldnn::memory::format_tag::nchw; - auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" + // auto type = mkldnn::memory::data_type::f32; + // auto format = mkldnn::memory::format_tag::nchw; + // auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any" - if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { - *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_src_md->data.format_kind = mkldnn_blocked; // overrides format - user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; - user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; - user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; - user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; - } + // if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { + // *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_src_md->data.format_kind = mkldnn_blocked; // overrides format + // user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + // user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + // user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; + // user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; + // } - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { - *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format - user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; - user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; - user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; - user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; - } + // if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { + // *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format + // user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; + // user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; + // user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; + // user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? 
diff_src->stridesOf()[dim3] : 1; + // } - if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { - *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); - user_dst_md->data.format_kind = mkldnn_blocked; // overrides format - user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; - user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; - user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; - user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; - } - }; + // if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { + // *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + // *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); + // user_dst_md->data.format_kind = mkldnn_blocked; // overrides format + // user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + // user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + // user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; + // user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + // } + // }; void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 8e09624e9..14cc41a96 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -62,7 +62,9 @@ namespace nd4j{ DECLARE_PLATFORM(lrn); - DECLARE_PLATFORM(batchnorm_new); + DECLARE_PLATFORM(batchnorm); + + DECLARE_PLATFORM(batchnorm_bp); DECLARE_PLATFORM(lstmLayer); } diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 2c6de814a..d35346e2b 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -413,7 +413,7 @@ namespace nd4j { return ctx; }; - nd4j::ops::batchnorm_new batchnorm; + nd4j::ops::batchnorm batchnorm; DeclarableBenchmark benchmark(batchnorm, "batchnorm"); output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization"); @@ -1822,7 +1822,7 @@ namespace nd4j { std::string result; long start = nowMs(); - + // set 1 nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); result += fastScalarBenchmark(); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index b6f5f125d..458858c57 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2385,129 +2385,6 @@ TEST_F(DeclarableOpsTests1, CompactLaunchTests2) { ASSERT_TRUE(exp.equalsTo(&z)); } -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test1) { - - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,2,3,2}); - auto gamma = NDArrayFactory::create('c', {2,3,2,3,2}); - 
auto beta = NDArrayFactory::create('c', {2,3,2,3,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011,10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - - -TEST_F(DeclarableOpsTests1, batchnorm_test2) { - - auto input = NDArrayFactory::create('c', {2,3,1,3,1}); - auto mean = NDArrayFactory::create('c', {1,3,2,1,2}); - auto variance = NDArrayFactory::create('c', {2,1,2,3,2}); - auto gamma = NDArrayFactory::create('c', {2,3,2,3,1}); - auto beta = NDArrayFactory::create('c', {1,3,2,1,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1. , 1. 
, 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test3) { - - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,1}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - - auto expected = NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011, 10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} - -//////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, batchnorm_test4) { - - auto input = NDArrayFactory::create('c', {3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance= NDArrayFactory::create('c', {2,3,1,3,2}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - - auto expected= NDArrayFactory::create('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 
-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428}); - - input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); - beta.assign(1.); - - nd4j::ops::batchnorm op; - - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - - ASSERT_EQ(ND4J_STATUS_OK, results->status()); - - auto output = results->at(0); - - ASSERT_TRUE(expected.isSameShapeStrict(output)); - ASSERT_TRUE(expected.equalsTo(output)); - - delete results; -} //////////////////////////////////////////////////////////////////// // TEST_F(DeclarableOpsTests1, sru_old_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 25fe3429a..c1e9ca5e2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2313,7 +2313,35 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { +TEST_F(DeclarableOpsTests10, batchnorm_test1) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {4}); @@ -2330,7 +2358,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { gamma.assign(1.2); beta.assign(1.); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); @@ -2346,7 +2374,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { } //////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {3}, {1.05, 1.1, 1.15}); @@ -2359,7 +2387,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { input.linspace(0.1, 0.1); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); @@ -2374,7 +2402,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { } 
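Editorial sketch (not part of the patch): the hard-coded expectations in the new `batchnorm_test1` above can be reproduced by hand from the standard normalization formula y = γ·(x − μ)/√(σ² + ε) + β. The small standalone program below (plain C++, no libnd4j dependency) recomputes the first expected value from the constants used in that test.

```cpp
// Recompute expected(0,0) of DeclarableOpsTests10.batchnorm_test1 by hand.
// Values are copied from the test above; nothing here is libnd4j API.
#include <cmath>
#include <cstdio>

int main() {
    const float x = 0.1f, mean = 1.05f, variance = 0.5f;       // first input element, channel 0
    const float gamma = -1.2f, beta = 10.0f, epsilon = 1e-5f;  // iArgs {1,1}: scale and offset applied
    const float y = gamma * (x - mean) / std::sqrt(variance + epsilon) + beta;
    std::printf("%.8f\n", y);  // ~11.61218734, matching the first value of 'expected'
    return 0;
}
```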
//////////////////////////////////////////////////////////////////// -TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { +TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { auto input = NDArrayFactory::create('c', {2,3,4}); auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); @@ -2387,7 +2415,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { input.linspace(0.1, 0.1); - nd4j::ops::batchnorm_new op; + nd4j::ops::batchnorm op; auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); @@ -2401,6 +2429,63 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, batchnorm_test5) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20., + -19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438, + -12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, batchnorm_test6) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 , + 20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 , + -12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + nd4j::ops::batchnorm op; + + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 84a1f2dc9..e36b78a98 100644 --- 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -2883,78 +2883,336 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { - auto input = NDArrayFactory::create('c', {3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,2}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto beta = NDArrayFactory::create('c', {1,2}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3,2}); + NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556, 0., 0.254556,0.509112, 0.763668, 1.018224, 1.272779, + 1.527335, 1.781891, 2.036447, 2.291003,2.545559, 2.800115, 3.054671, 3.309227,3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); mean.assign(1.); variance.assign(0.5); gamma.assign(1.2); - beta.assign(1.); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1,1}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { - auto input = NDArrayFactory::create('c', {2,3,2,3,2}); - auto mean = NDArrayFactory::create('c', {2,3,2}); - auto variance = NDArrayFactory::create('c', {2,3,1,3,1}); - auto gamma = NDArrayFactory::create('c', {1,1}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3,2}); + NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE); + NDArray mean ('c', {3}, {1.05, 1.1, 1.15}); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}); + NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}); + NDArray beta ('c', {3}, nd4j::DataType::DOUBLE); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742, 0., 0.251742,0.501992, 0.752989, 1.003985, 1.254981, + 1.527335, 1.781891, 2.036447, 2.291003,2.517418, 2.76916 , 3.020902, 3.272644,3.513947, 3.764943, 4.015939, 4.266936}); + NDArray expdLdG('c', {3}, 
{5.81236 , 7.048771, 12.155388}); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}); input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); - gamma.assign(1.2); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma}, {1e-5}, {1,0}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1,0}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { - auto input = NDArrayFactory::create('c', {2,3,1,3}); - auto mean = NDArrayFactory::create('c', {1,3,2,1}); - auto variance = NDArrayFactory::create('c', {2,1,2,3}); - auto dLdO = NDArrayFactory::create('c', {2,3,2,3}); + NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE); + NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); + NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); + NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}); + NDArray beta ('c', {2,1,4}, nd4j::DataType::DOUBLE); + NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE); + + NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742, 0., 0.251556,0.509112, 0.755225, 1.003985, 1.25778 , + 1.517885, 1.784991, 2.05947 , 2.341504,2.529808, 2.804986, 3.089205, 3.382173,3.541731, 3.824981, 4.11894 , 4.422841}); + NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }); + NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. 
, 0.45, 4.5 , 4.95, 5.4 , 5.85}); input.linspace(0.1, 0.1); - mean.assign(1.); - variance.assign(0.5); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); - const OpArgsHolder argsHolderFF({&input, &mean, &variance}, {1e-5}, {0,0}); - const OpArgsHolder argsHolderBP({&input, &mean, &variance, &dLdO}, {1e-5}, {0,0}); + nd4j::ops::batchnorm_bp op; - nd4j::ops::batchnorm opFF; - nd4j::ops::batchnorm_bp opBP; + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2}); - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - ASSERT_TRUE(isGradCorrect); + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { + + NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.9502 , 0.569207, 0.314641}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { + + NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779,1.018224, 0.763668,-0.466136, -0.233068,0., 0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169, 1.501697,1.716225, 1.930753, + -2.545559, -2.800115,-3.054671, -3.309227,3.262951, 3.496019,3.729087, 3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618, 4.934146,5.148675, 5.363203}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); + 
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { + + NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753, + -2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829 , -4.427164, 4.50509 , -5.60023 , 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { + + NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584,0.509112, -0.233068, -0., 0.214528,-0.509112, 0.699204, -0.885433, 1.072641,-1.527335, 1.631475, -1.770866, + 1.930753,-2.545559, 2.563747, -2.656298, 2.788865,-3.563783, 3.496019, -3.541731, 3.646978,-4.582006, 4.42829 , -4.427164, + 4.50509 ,-5.60023 , 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803 , 6.221315,-7.636677, 7.225105, -7.083463, + 7.079428,-8.6549 , 8.157377, -7.968895, 7.93754 ,-9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192 , -9.739761, + 9.653765,-11.709571, 10.954192, 
-10.625194, 10.511877,-12.727795, 11.886464, -11.510627, 11.36999 ,-13.746018, 12.818735, -12.39606 , 12.228102}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { + + NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, nd4j::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0. , -0.254556, 0.466136, 0.699204, 0.932272, 1.16534 , 1.398407, 1.631475, 1.864543, 2.097611, + -2.213582, -2.43494 , -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309 , 3.861506, 4.076034, 4.290562, 4.50509 , 4.719618, 4.934146, 5.148675, 5.363203, + -6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784, + -9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999 , 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + nd4j::ops::batchnorm_bp op; + + auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto dLdI = results->at(0); + auto dLdG = results->at(3); + auto dLdB = results->at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + delete results; +} /* //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index c95fc11e3..829117bed 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -64,7 +64,7 @@ TEST_F(MklDnnTests, helpers_includer) { 
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp; nd4j::ops::platforms::PLATFORM_lrn lrn; - nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm; + nd4j::ops::platforms::PLATFORM_batchnorm batchnorm; printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm}); #endif diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index a8c50abdf..20ff5918c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -142,7 +142,7 @@ public class BatchNorm extends DynamicCustomOp { @Override public String opName() { - return "batchnorm_new"; + return "batchnorm"; } @Override
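Editorial usage sketch (not part of the patch): how the renamed forward op and the new `batchnorm_bp` op are invoked, mirroring the calling convention of the tests added above. The include paths are assumptions and may need adjusting; the op classes, `NDArray` constructors and the `execute({inputs}, {tArgs}, {iArgs})` signature are taken from the tests in this patch.

```cpp
// Hypothetical includes -- adjust to the local libnd4j layout.
#include <NDArray.h>
#include <ops/declarable/CustomOperations.h>

using namespace nd4j;

int main() {
    NDArray input   ('c', {2, 4}, nd4j::DataType::FLOAT32);
    NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
    NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1},   nd4j::DataType::FLOAT32);
    NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
    NDArray beta    ('c', {4}, {10, 20, -10, -20},     nd4j::DataType::FLOAT32);
    NDArray gradO   ('c', {2, 4}, nd4j::DataType::FLOAT32);
    input.linspace(0.1, 0.1);
    gradO.linspace(-0.9, 0.15);

    nd4j::ops::batchnorm fwd;      // renamed from batchnorm_new by this patch
    auto fwdResults = fwd.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1});

    nd4j::ops::batchnorm_bp bwd;   // note: gradO follows variance, before gamma/beta
    auto bwdResults = bwd.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1, 1});

    // With applyScale=1 and applyOffset=1 the outputs are:
    // 0 -> dLdI, 1 -> dLdM, 2 -> dLdV, 3 -> dLdG, 4 -> dLdB
    auto dLdI = bwdResults->at(0);
    dLdI->printBuffer();

    delete fwdResults;
    delete bwdResults;
    return 0;
}
```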