Shyrma bn mkl bp (#14)
* - write code for new batchnorm backprop Signed-off-by: Yurii <iuriish@yahoo.com> * - testing batchnorm backprop Signed-off-by: Yurii <iuriish@yahoo.com> * - write code for batchnorm backprop based on mkl dnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in batchnorm_bp mkl dnn Signed-off-by: Yurii <iuriish@yahoo.com> * - made corrections required by reviewer Signed-off-by: Yurii <iuriish@yahoo.com> * - change name in java wrapper for batchnorm op Signed-off-by: Yurii <iuriish@yahoo.com>
This commit is contained in:
		
							parent
							
								
									d333d29099
								
							
						
					
					
						commit
						029a69a835
					
				| @ -60,6 +60,7 @@ namespace nd4j { | ||||
|         Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor); | ||||
|         Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape); | ||||
|         Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape); | ||||
|         Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo); | ||||
| 
 | ||||
|         Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace); | ||||
|         Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true); | ||||
|  | ||||
| @ -99,6 +99,10 @@ namespace nd4j { | ||||
|         return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); | ||||
|     } | ||||
| 
 | ||||
|     Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) { | ||||
|         return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo))); | ||||
|     } | ||||
| 
 | ||||
|     Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { | ||||
|         auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); | ||||
|         return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); | ||||
|  | ||||
| @ -102,6 +102,10 @@ namespace nd4j { | ||||
|         return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); | ||||
|     } | ||||
| 
 | ||||
|     Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) { | ||||
|         return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo))); | ||||
|     } | ||||
| 
 | ||||
|     Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) { | ||||
|         auto descriptor = ShapeDescriptor::emptyDescriptor(dataType); | ||||
|         return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); | ||||
|  | ||||
| @ -29,84 +29,8 @@ namespace nd4j { | ||||
| namespace ops { | ||||
| 
 | ||||
| 
 | ||||
| CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {     | ||||
|     auto input    = INPUT_VARIABLE(0); | ||||
|     auto mean     = INPUT_VARIABLE(1); | ||||
|     auto variance = INPUT_VARIABLE(2); | ||||
|     NDArray *gamma    = nullptr; | ||||
|     NDArray *beta     = nullptr; | ||||
| 
 | ||||
|     auto output   = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|     const bool applyScale  = (bool)INT_ARG(0); | ||||
|     const bool applyOffset = (bool)INT_ARG(1); | ||||
| 
 | ||||
|     // FIXME: double?
 | ||||
|     const double epsilon     = T_ARG(0); | ||||
| 
 | ||||
|     if(applyScale) | ||||
|         gamma = INPUT_VARIABLE(3);     | ||||
|     if(applyOffset) | ||||
|         beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));     | ||||
| 
 | ||||
|     std::vector<const NDArray*> inArrs(block.width()); | ||||
|     for(int i = 0; i < block.width(); ++i) | ||||
|         inArrs[i] = INPUT_VARIABLE(i); | ||||
| 
 | ||||
|     // check whether all input shapes are mutually broadcastable
 | ||||
|     Nd4jLong* outShapeInfo = nullptr; | ||||
|     const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); | ||||
|     REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); | ||||
| 
 | ||||
|     // normalized output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
 | ||||
| 
 | ||||
|     auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); | ||||
|     if(applyScale) | ||||
|         sigmaInvGam *= *gamma; | ||||
| 
 | ||||
|     NDArray inputMinusMean; | ||||
|     if(!input->isSameShape(output) && !mean->isSameShape(output)) { | ||||
|         auto inputTiled = NDArray(output, false, block.launchContext()); | ||||
|         input->tile(inputTiled); | ||||
|         inputMinusMean = inputTiled - *mean; | ||||
|     } | ||||
|     else | ||||
|         inputMinusMean = *input - *mean;        | ||||
| 
 | ||||
|     if (applyOffset) | ||||
|         output->assign(inputMinusMean * sigmaInvGam + *beta); | ||||
|     else  | ||||
|         output->assign(inputMinusMean * sigmaInvGam); | ||||
| 
 | ||||
|     return Status::OK(); | ||||
| } | ||||
| 
 | ||||
|     DECLARE_TYPES(batchnorm) { | ||||
|         getOpDescriptor() | ||||
|                 ->setAllowedInputTypes(nd4j::DataType::ANY) | ||||
|                 ->setAllowedOutputTypes({ALL_FLOATS}); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| DECLARE_SHAPE_FN(batchnorm) {         | ||||
| 
 | ||||
|     std::vector<const NDArray*> inArrs(block.width()); | ||||
|     auto in = inputShape->at(0); | ||||
|     for(int i = 0; i < block.width(); ++i) | ||||
|         inArrs[i] = INPUT_VARIABLE(i); | ||||
| 
 | ||||
|     // check whether all input shapes are mutually broadcastable
 | ||||
|     Nd4jLong* outShapeInfo = nullptr; | ||||
|     const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); | ||||
|     REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !"); | ||||
| 
 | ||||
|     auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, DataTypeUtils::pickFloatingType(ArrayOptions::dataType(in)))); | ||||
|     return SHAPELIST(result); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { | ||||
| CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) { | ||||
| 
 | ||||
|     auto input    = INPUT_VARIABLE(0); | ||||
|     auto mean     = INPUT_VARIABLE(1); | ||||
| @ -123,7 +47,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { | ||||
|     if(applyScale) | ||||
|         gamma = INPUT_VARIABLE(3); | ||||
|     if(applyOffset) | ||||
|         beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale)); | ||||
|         beta = INPUT_VARIABLE(3 + (int)applyScale); | ||||
| 
 | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     const int inRank = input->rankOf(); | ||||
| @ -137,30 +61,31 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { | ||||
|         axes.push_back(inRank-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     const int numOfAxes = axes.size(); | ||||
|     REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_NEW op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); | ||||
| 
 | ||||
|     // get, for example, something like {1, inDim1, 1, inDim3, 1} if axes = {1, 3}
 | ||||
|     std::vector<Nd4jLong> expShapeWithUnities(inRank, 1); | ||||
|     for(int i = 0; i < numOfAxes; ++i) | ||||
|         expShapeWithUnities[axes[i]] = input->sizeAt(axes[i]); | ||||
|     REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); | ||||
| 
 | ||||
|     // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
 | ||||
|     // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
 | ||||
|     std::vector<Nd4jLong> expShape = numOfAxes == 1 ? std::vector<Nd4jLong>(1, input->sizeAt(axes[0])) : expShapeWithUnities; | ||||
|     std::string expShapeStr = ShapeUtils::shapeAsString(expShape); | ||||
|     std::vector<Nd4jLong> expShape; | ||||
|     if(numOfAxes == 1) | ||||
|         expShape.push_back(input->sizeAt(axes[0])); | ||||
|     else {      // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
 | ||||
|         expShape = std::vector<Nd4jLong>(inRank, 1); | ||||
|         for(uint i = 0; i < numOfAxes; ++i) | ||||
|             expShape[axes[i]] = input->sizeAt(axes[i]); | ||||
|     } | ||||
| 
 | ||||
|     REQUIRE_TRUE(ShapeUtils::shapeAsString(mean)     == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of mean array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(mean).c_str()); | ||||
|     REQUIRE_TRUE(ShapeUtils::shapeAsString(variance) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of variance array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(variance).c_str()); | ||||
|     REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); | ||||
|     REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); | ||||
|     if(gamma) | ||||
|         REQUIRE_TRUE(ShapeUtils::shapeAsString(gamma) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of gamma array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(gamma).c_str()); | ||||
|         REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); | ||||
|     if(beta) | ||||
|         REQUIRE_TRUE(ShapeUtils::shapeAsString(beta) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of beta array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(beta).c_str()); | ||||
|         REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); | ||||
| 
 | ||||
|     // types of all input arrays should be the same
 | ||||
|     for(int i = 1; i < block.width(); ++i) | ||||
|         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !"); | ||||
|         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !"); | ||||
| 
 | ||||
|     nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0); | ||||
|     nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0); | ||||
| 
 | ||||
|     // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
 | ||||
|     helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon); | ||||
| @ -168,15 +93,15 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { | ||||
|     return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| DECLARE_TYPES(batchnorm_new) { | ||||
| DECLARE_TYPES(batchnorm) { | ||||
|     getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true); | ||||
| } | ||||
| 
 | ||||
| DECLARE_SHAPE_FN(batchnorm_new) { | ||||
| DECLARE_SHAPE_FN(batchnorm) { | ||||
| 
 | ||||
|     auto inShapeInfo = inputShape->at(0); | ||||
|     DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); | ||||
|       | ||||
| 
 | ||||
|     auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace());    // output shape is identical to input shape
 | ||||
| 
 | ||||
|     return SHAPELIST(CONSTANT(outShapeInfo)); | ||||
| @ -184,290 +109,177 @@ DECLARE_SHAPE_FN(batchnorm_new) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) { | ||||
|     auto input    = INPUT_VARIABLE(0); | ||||
|     auto mean     = INPUT_VARIABLE(1); | ||||
|     auto variance = INPUT_VARIABLE(2); | ||||
|     NDArray *gamma    = nullptr; | ||||
|     NDArray *beta     = nullptr; | ||||
|     NDArray *dLdO     = nullptr;                 // next epsilon
 | ||||
| 
 | ||||
|     auto dLdI = OUTPUT_VARIABLE(0); | ||||
|     auto dLdM = OUTPUT_VARIABLE(1); | ||||
|     auto dLdV = OUTPUT_VARIABLE(2); | ||||
|     NDArray *dLdG = nullptr; | ||||
|     NDArray *dLdB = nullptr; | ||||
|     NDArray* input    = INPUT_VARIABLE(0); | ||||
|     NDArray* mean     = INPUT_VARIABLE(1); | ||||
|     NDArray* variance = INPUT_VARIABLE(2); | ||||
|     NDArray* dLdO     = INPUT_VARIABLE(3);    // next epsilon
 | ||||
|     NDArray* gamma    = nullptr; | ||||
|     NDArray* beta     = nullptr; | ||||
| 
 | ||||
|     const bool applyScale  = (bool)INT_ARG(0); | ||||
|     const bool applyOffset = (bool)INT_ARG(1); | ||||
| 
 | ||||
|     // FIXME: double?
 | ||||
|     const double    epsilon     = T_ARG(0); | ||||
|     NDArray* dLdI = OUTPUT_VARIABLE(0); | ||||
|     NDArray* dLdM = OUTPUT_VARIABLE(1); | ||||
|     NDArray* dLdV = OUTPUT_VARIABLE(2); | ||||
|     NDArray* dLdG = nullptr; | ||||
|     NDArray* dLdB = nullptr; | ||||
| 
 | ||||
|     const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset); | ||||
|     const bool   applyScale  = (bool)INT_ARG(0); | ||||
|     const bool   applyOffset = (bool)INT_ARG(1); | ||||
|     const float  epsilon     = T_ARG(0); | ||||
| 
 | ||||
|     if(applyScale) { | ||||
|         gamma = INPUT_VARIABLE(3); | ||||
|         gamma = INPUT_VARIABLE(4); | ||||
|         dLdG  = OUTPUT_VARIABLE(3); | ||||
|     } | ||||
|     if(applyOffset) { | ||||
|         beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale)); | ||||
|         dLdB = OUTPUT_VARIABLE(3 + static_cast<int>(applyScale)); | ||||
|         beta = INPUT_VARIABLE(4 + (int)applyScale); | ||||
|         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); | ||||
|     } | ||||
|          | ||||
|     dLdO = INPUT_VARIABLE(3 + dLdONum); | ||||
|      | ||||
|     std::vector<const NDArray*> inArrs(block.width()); | ||||
|     for(int i = 0; i < 4 + dLdONum; ++i) | ||||
|         inArrs[i] = INPUT_VARIABLE(i); | ||||
| 
 | ||||
|     // check whether all input shapes are mutually broadcastable
 | ||||
|     Nd4jLong* outShapeInfo = nullptr; | ||||
|     const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); | ||||
|     REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     const int inRank = input->rankOf(); | ||||
| 
 | ||||
|     // get axes args to normalize input array over
 | ||||
|     std::vector<int> axes; | ||||
|     if(numOfIntArgs > 2) | ||||
|         for(int i = 2; i < numOfIntArgs; ++i) | ||||
|             axes.push_back(INT_ARG(i)); | ||||
|     else | ||||
|         axes.push_back(inRank-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     const int numOfAxes = axes.size(); | ||||
|     REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank); | ||||
| 
 | ||||
|     // evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
 | ||||
|     // for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
 | ||||
|     std::vector<Nd4jLong> expShape; | ||||
|     if(numOfAxes == 1) | ||||
|         expShape.push_back(input->sizeAt(axes[0])); | ||||
|     else {      // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
 | ||||
|         expShape = std::vector<Nd4jLong>(inRank, 1); | ||||
|         for(uint i = 0; i < numOfAxes; ++i) | ||||
|             expShape[axes[i]] = input->sizeAt(axes[i]); | ||||
|     } | ||||
| 
 | ||||
|     REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str()); | ||||
|     REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str()); | ||||
|     if(gamma) | ||||
|         REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str()); | ||||
|     if(beta) | ||||
|         REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str()); | ||||
| 
 | ||||
|     REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); | ||||
| 
 | ||||
|     // types of all input arrays should be the same (except dLdO)
 | ||||
|     for(int i = 1; i < block.width() - 1; ++i) | ||||
|         if(i != 3) | ||||
|             REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !"); | ||||
| 
 | ||||
|     // ***** calculations ***** //
 | ||||
| 
 | ||||
|     auto sigmaInv = (*variance + epsilon).transform(transform::RSqrt); | ||||
|      | ||||
|     NDArray sigmaInvGamdLdO = -sigmaInv * *dLdO; | ||||
|     if(applyScale) | ||||
|         sigmaInvGamdLdO *= *gamma; | ||||
|     // formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
 | ||||
| 
 | ||||
|     NDArray inputMinusMean; | ||||
|     if(!input->isSameShape(dLdO) && !mean->isSameShape(dLdO)) { | ||||
|         auto inputTiled = NDArray(dLdO, false, block.launchContext()); | ||||
|         input->tile(inputTiled); | ||||
|         inputMinusMean = inputTiled - *mean; | ||||
|     } | ||||
|     else | ||||
|         inputMinusMean = *input - *mean; | ||||
|     // consider mean and variance as constants (since we get them as inputs and don't calculate them)
 | ||||
|     // dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
 | ||||
|     // dLdV = (-0.5  * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
 | ||||
|     // dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
 | ||||
|     // dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
 | ||||
|     // dLdB = dLdO_sum
 | ||||
| 
 | ||||
|     const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes); | ||||
| 
 | ||||
|     NDArray temp1 = *variance + epsilon; | ||||
|     temp1.applyTransform(transform::Reciprocal);            // 1 / (variance + epsilon)
 | ||||
|     auto temp2 = temp1.transform(transform::Sqrt);     // 1 / (variance + epsilon)^0.5
 | ||||
|     if(applyScale) | ||||
|         temp2 *= *gamma;                                    // gamma / (variance + epsilon)^0.5
 | ||||
| 
 | ||||
|     NDArray temp3(input); // empty array with same shape as input
 | ||||
|     input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3);  // input - mean
 | ||||
|     temp3 *= *dLdO;                                                        // (input - mean) * dLdO
 | ||||
| 
 | ||||
|     const bool keepUnitiesInShape = inRank == mean->rankOf(); | ||||
| 
 | ||||
|     // dLdI
 | ||||
|     if(!dLdI->isSameShape(dLdO)) | ||||
|         dLdI->assign( (-sigmaInvGamdLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdI->getShapeInfo(), dLdO->getShapeInfo())) ); | ||||
|     else | ||||
|         dLdI->assign(-sigmaInvGamdLdO); | ||||
|     dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI); | ||||
| 
 | ||||
|     // dLdM
 | ||||
|     if(!dLdM->isSameShape(dLdO)) | ||||
|         dLdM->assign( sigmaInvGamdLdO.reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdM->getShapeInfo(), dLdO->getShapeInfo())) ); | ||||
|     else | ||||
|         dLdM->assign(sigmaInvGamdLdO); | ||||
|     dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);    // dLdO sum over excluded axes
 | ||||
| 
 | ||||
|     // dLdV
 | ||||
|     if(!dLdV->isSameShape(dLdO)) { | ||||
|         dLdV->assign( (sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdV->getShapeInfo(), dLdO->getShapeInfo())) ); | ||||
|     } | ||||
|     else | ||||
|         dLdV->assign(sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f); | ||||
|     // dLdB
 | ||||
|     if(applyOffset) | ||||
|         dLdB->assign(dLdM); | ||||
| 
 | ||||
|     // dLdM
 | ||||
|     // dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
 | ||||
|     // dLdM->applyTransform(nd4j::transform::Neg);
 | ||||
|     *dLdM = 0;      // put zeros so far
 | ||||
| 
 | ||||
|     //dLdV
 | ||||
|     temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);     // ((input - mean) * dLdO)_sum
 | ||||
| 
 | ||||
|     // dLdG
 | ||||
|     if(applyScale) { | ||||
|         if(!dLdG->isSameShape(dLdO)) | ||||
|             dLdG->assign( (sigmaInv * inputMinusMean * *dLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdG->getShapeInfo(), dLdO->getShapeInfo())) ); | ||||
|         else | ||||
|             dLdG->assign(sigmaInv * inputMinusMean * *dLdO); | ||||
|         dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG); | ||||
|         // dLdV->assign(dLdG);
 | ||||
|         dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma); | ||||
|     } | ||||
|     else | ||||
|         // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
 | ||||
| 
 | ||||
|     // dLdB
 | ||||
|     if(applyOffset) { | ||||
|         if(!dLdB->isSameShape(dLdO)) | ||||
|             dLdB->assign(dLdO->reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdB->getShapeInfo(), dLdO->getShapeInfo())) ); | ||||
|         else | ||||
|             dLdB->assign(dLdO); | ||||
|     } | ||||
|     // dLdV
 | ||||
|     // dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
 | ||||
|     // *dLdV *= -0.5;
 | ||||
|     *dLdV = 0;      // put zeros so far
 | ||||
| 
 | ||||
|     return Status::OK(); | ||||
| } | ||||
| 
 | ||||
|         DECLARE_TYPES(batchnorm_bp) { | ||||
|             getOpDescriptor() | ||||
|                     ->setAllowedInputTypes(0, nd4j::DataType::ANY) | ||||
|                     ->setAllowedInputTypes(1, nd4j::DataType::ANY) | ||||
|                     ->setAllowedInputTypes(2, nd4j::DataType::ANY) | ||||
|                     ->setAllowedInputTypes(3, nd4j::DataType::ANY) | ||||
|                     ->setAllowedInputTypes(4, nd4j::DataType::ANY) | ||||
|                     ->setAllowedInputTypes(5, {ALL_FLOATS}) | ||||
|                     ->setAllowedOutputTypes({ALL_FLOATS}); | ||||
|         } | ||||
| DECLARE_TYPES(batchnorm_bp) { | ||||
|     getOpDescriptor() | ||||
|             ->setAllowedInputTypes(0, nd4j::DataType::ANY) | ||||
|             ->setAllowedInputTypes(1, nd4j::DataType::ANY) | ||||
|             ->setAllowedInputTypes(2, nd4j::DataType::ANY) | ||||
|             ->setAllowedInputTypes(3, {ALL_FLOATS}) | ||||
|             ->setAllowedInputTypes(4, nd4j::DataType::ANY) | ||||
|             ->setAllowedInputTypes(5, nd4j::DataType::ANY) | ||||
|             ->setAllowedOutputTypes({ALL_FLOATS}); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| 
 | ||||
| DECLARE_SHAPE_FN(batchnorm_bp) { | ||||
| 
 | ||||
|     Nd4jLong* inShapeInfo   = inputShape->at(0); | ||||
|     Nd4jLong* meanShapeInfo = inputShape->at(1); | ||||
| 
 | ||||
|     const bool applyScale  = (bool)INT_ARG(0); | ||||
|     const bool applyOffset = (bool)INT_ARG(1); | ||||
| 
 | ||||
|     const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset); | ||||
|     DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo)); | ||||
| 
 | ||||
|     std::vector<const NDArray*> inArrs(block.width()); | ||||
|     for(int i = 0; i < 4 + dLdONum; ++i) | ||||
|         inArrs[i] = INPUT_VARIABLE(i); | ||||
|     auto shapes = SHAPELIST(); | ||||
| 
 | ||||
|     // check whether all input shapes are mutually broadcastable
 | ||||
|     Nd4jLong* outShapeInfo = nullptr; | ||||
|     const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace()); | ||||
|     REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !"); | ||||
|     // dLdI shapeInfo
 | ||||
|     shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo)); | ||||
| 
 | ||||
|     Nd4jLong* dLdIShapeInfo(nullptr), *dLdMShapeInfo(nullptr), *dLdVShapeInfo(nullptr), *dLdGShapeInfo(nullptr), *dLdBShapeInfo(nullptr); | ||||
|     COPY_SHAPE(inputShape->at(0), dLdIShapeInfo); | ||||
|     COPY_SHAPE(inputShape->at(1), dLdMShapeInfo); | ||||
|     COPY_SHAPE(inputShape->at(2), dLdVShapeInfo); | ||||
|     // dLdM shapeInfo
 | ||||
|     shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo)); | ||||
| 
 | ||||
|     if(applyScale) { | ||||
|         COPY_SHAPE(inputShape->at(3), dLdGShapeInfo); | ||||
|     } | ||||
|     if(applyOffset){ | ||||
|         COPY_SHAPE(inputShape->at(3 + static_cast<int>(applyScale)), dLdBShapeInfo); | ||||
|     } | ||||
|     // dLdV shapeInfo (same as dLdM)
 | ||||
|     shapes->push_back(shapes->at(shapes->size()-1)); | ||||
| 
 | ||||
|     if(!applyScale && !applyOffset) | ||||
|         return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo)); | ||||
|     // dLdG shapeInfo (same as dLdM)
 | ||||
|     if(applyScale) | ||||
|         shapes->push_back(shapes->at(shapes->size()-1)); | ||||
| 
 | ||||
|     if(applyScale && !applyOffset) | ||||
|         return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo)); | ||||
|     // dLdB shapeInfo (same as dLdM)
 | ||||
|     if(applyOffset) | ||||
|         shapes->push_back(shapes->at(shapes->size()-1)); | ||||
| 
 | ||||
|     if(!applyScale && applyOffset) | ||||
|         return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdBShapeInfo)); | ||||
| 
 | ||||
|     return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo), CONSTANT(dLdBShapeInfo)); | ||||
|     return shapes; | ||||
| } | ||||
|         // //////////////////////////////////////////////////////////////////////////
 | ||||
|         // CONFIGURABLE_OP_IMPL(batchnorm_bp, 5, 1, true, 0, 1) {
 | ||||
| 
 | ||||
|         //     NDArray<T>* input = INPUT_VARIABLE(0);
 | ||||
|         //     NDArray<T>* epsilon = INPUT_VARIABLE(1);
 | ||||
|         //     NDArray<T>* gamma = INPUT_VARIABLE(2);
 | ||||
|         //     NDArray<T>* dGlobalMeanView = INPUT_VARIABLE(3);
 | ||||
|         //     NDArray<T>* dGlobalVarView = INPUT_VARIABLE(4);
 | ||||
|         //     NDArray<T>* outEpsilon = this->getZ(block);
 | ||||
|         //     std::vector<int> argI = *(block.getIArguments());
 | ||||
|         //     const int bS = epsilon->sizeAt(0);
 | ||||
|         //     bool isLockGammaBeta = (bool)argI[0];
 | ||||
|         //     const int* epsilonShape = epsilon->getShapeInfo() + 1;
 | ||||
|         //     const T eps = (T)1e-5;
 | ||||
| 
 | ||||
|         //     int rank = epsilon->rankOf();
 | ||||
|         //     std::initializer_list<int> dimensions;
 | ||||
|         //     int effectiveBatchSize;
 | ||||
|         //     if (rank == 2) {
 | ||||
|         //         dimensions = {0};
 | ||||
|         //         effectiveBatchSize = bS;
 | ||||
|         //     }
 | ||||
|         //     else if (rank == 4) {
 | ||||
|         //         dimensions = {0, 2, 3};
 | ||||
|         //         effectiveBatchSize = input->sizeAt(0)*input->sizeAt(2)*input->sizeAt(3);
 | ||||
|         //     }
 | ||||
|         //     else
 | ||||
|         //         throw "Graph operation batchnorm_bp: the epsilon rank must be equal to 2 or 4 !";
 | ||||
| 
 | ||||
|         //     NDArray<T> *mean(nullptr), *var(nullptr), *dBeta(nullptr), *dGamma(nullptr), *dLdVar(nullptr), *dxmu1(nullptr), *dxmu2(nullptr);
 | ||||
|         //     mean = input->template reduceAlongDimension<simdOps::Mean<T>>(dimensions);
 | ||||
|         //     var = input->template varianceAlongDimension<simdOps::SummaryStatsVariance<T>>(false, dimensions);
 | ||||
|         //     var->template applyScalar<simdOps::Add<T>>(eps, nullptr);
 | ||||
|         //     auto std = new NDArray<T>(var->getShapeInfo(), block.getWorkspace());
 | ||||
|         //     var->template applyTransform<simdOps::Sqrt<T>>(std, nullptr);
 | ||||
| 
 | ||||
|         //     auto xMu = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
 | ||||
|         //     auto xHat = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
 | ||||
|         //     auto temp1 = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
 | ||||
|         //     auto temp2 = new NDArray<T>(std->getShapeInfo(), block.getWorkspace());
 | ||||
|         //     auto dGammaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
 | ||||
|         //     auto dBetaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
 | ||||
|         //     auto dxhat = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
 | ||||
| 
 | ||||
|         //     if (rank == 2) {
 | ||||
|         //         input->subRowVector(mean, xMu);
 | ||||
|         //         xMu->divRowVector(std, xHat);
 | ||||
|         //     }
 | ||||
|         //     else {
 | ||||
|         //         input->template applyBroadcast<simdOps::Subtract<T>>({1}, mean, xMu, nullptr);
 | ||||
|         //         xMu->template applyBroadcast<simdOps::Divide<T>>({1}, std, xHat, nullptr);
 | ||||
|         //     }
 | ||||
| 
 | ||||
|         //     dBeta = epsilon->sum(dimensions); // dL/dBeta = sum_examples dL/dOut
 | ||||
|         //     epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(xHat, temp1, nullptr);   //dL/dGamma = sum_examples dL/dOut .* xHat
 | ||||
|         //     dGamma = temp1->sum(dimensions);  //dL/dGamma = sum_examples dL/dOut .* xHat
 | ||||
| 
 | ||||
|         //     if (isLockGammaBeta)
 | ||||
|         //         epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(gamma, dxhat, nullptr);
 | ||||
|         //     else {// Standard case
 | ||||
|         //         if(rank == 2)
 | ||||
|         //             epsilon->mulRowVector(gamma, dxhat); //dL/dxHat = dL/dOut . gamma        Shape: [minibatchSize, nOut]
 | ||||
|         //         else
 | ||||
|         //             epsilon->template applyBroadcast<simdOps::Multiply<T>>({1}, gamma, dxhat, nullptr);
 | ||||
|         //     }
 | ||||
| 
 | ||||
|         //     // dLdVar - dL/dVariance, shape: [1, miniBatch]
 | ||||
|         //     dxhat->template applyPairwiseTransform<simdOps::Multiply<T>>(xMu, temp1, nullptr);
 | ||||
|         //     dLdVar = temp1->sum(dimensions);
 | ||||
|         //     dLdVar->template applyScalar<simdOps::Multiply<T>>((T)-0.5, nullptr);
 | ||||
|         //     T powParams[] = {(T)(-3.)};
 | ||||
|         //     std->template applyTransform<simdOps::Pow<T>>(temp2, powParams);
 | ||||
|         //     dLdVar->template applyPairwiseTransform<simdOps::Multiply<T>>(temp2, nullptr);
 | ||||
| 
 | ||||
|         //     //dL/dmu
 | ||||
|         //     dxmu1 = dxhat->sum(dimensions);
 | ||||
|         //     dxmu1->template applyPairwiseTransform<simdOps::Divide<T>>(std, nullptr);
 | ||||
|         //     dxmu1->template applyTransform<simdOps::Neg<T>>();
 | ||||
|         //     dxmu2 = xMu->sum(dimensions);
 | ||||
|         //     dxmu2->template applyScalar<simdOps::Multiply<T>>((T)(-2.)/effectiveBatchSize);
 | ||||
|         //     dxmu2->template applyPairwiseTransform<simdOps::Multiply<T>>(dLdVar, nullptr);
 | ||||
| 
 | ||||
|         //     dxmu1->template applyPairwiseTransform<simdOps::Add<T>>(dxmu2, nullptr);
 | ||||
|         //     NDArray<T>* dLdmu = dxmu1;      //  = dL/dmu Shape: [1, nOut]
 | ||||
| 
 | ||||
|         //     //Note the array reuse here: dxhat, xMu, dLdVar, dLdmu - all are invalid after this line (but aren't used later anyway)
 | ||||
|         //     NDArray<T>* dLdx = dxhat;
 | ||||
|         //     dLdVar->template applyScalar<simdOps::Multiply<T>>((T)(2.)/effectiveBatchSize);
 | ||||
|         //     dLdmu->template applyScalar<simdOps::Multiply<T>>((T)(1.)/effectiveBatchSize);
 | ||||
|         //     if(rank == 2) {
 | ||||
|         //         dLdx->divRowVector(std, dLdx);
 | ||||
|         //         xMu->mulRowVector(dLdVar, xMu);
 | ||||
|         //     }
 | ||||
|         //     else {
 | ||||
|         //         dLdx->template applyBroadcast<simdOps::Divide<T>>({1}, std, dLdx, nullptr);
 | ||||
|         //         xMu->template applyBroadcast<simdOps::Multiply<T>>({1}, dLdVar, xMu, nullptr);
 | ||||
|         //     }
 | ||||
|         //     dLdx->template applyPairwiseTransform<simdOps::Add<T>>(xMu, nullptr);
 | ||||
|         //     if(rank == 2)
 | ||||
|         //         dLdx->addRowVector(dLdmu, dLdx);
 | ||||
|         //     else
 | ||||
|         //         dLdx->template applyBroadcast<simdOps::Add<T>>({1}, dLdmu, dLdx, nullptr);
 | ||||
| 
 | ||||
|         //     *outEpsilon = *dLdx;
 | ||||
| 
 | ||||
|         //     //TODO rework this to avoid the assign here
 | ||||
|         //     // dGammaView->assign(dGamma);
 | ||||
|         //     // dBetaView->assign(dBeta);
 | ||||
|         //     // dGlobalMeanView->assign((T)0.);
 | ||||
|         //     // dGlobalVarView->assign((T)0.);
 | ||||
|         //     // retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
 | ||||
|         //     // retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
 | ||||
|         //     // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
 | ||||
|         //     // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
 | ||||
| 
 | ||||
|         //     delete std;
 | ||||
|         //     delete xMu;
 | ||||
|         //     delete xHat;
 | ||||
|         //     delete mean;
 | ||||
|         //     delete var;
 | ||||
|         //     delete dBeta;
 | ||||
|         //     delete dGamma;
 | ||||
|         //     delete dLdVar;
 | ||||
|         //     delete dxmu1;
 | ||||
|         //     delete dxmu2;
 | ||||
|         //     delete temp1;
 | ||||
|         //     delete temp2;
 | ||||
|         //     delete dxhat;
 | ||||
|         //     delete dGammaView;
 | ||||
|         //     delete dBetaView;
 | ||||
| 
 | ||||
|         //     return ND4J_STATUS_OK;
 | ||||
|         // }
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
|  | ||||
| @ -29,12 +29,12 @@ namespace nd4j { | ||||
|         #if NOT_EXCLUDED(OP_softmax) | ||||
|         DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0); | ||||
|         DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0); | ||||
|         #endif        | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|          * Local response normalization implementation as TF. | ||||
|          * input: 4D array | ||||
|          *  | ||||
|          * | ||||
|          * T args: | ||||
|          * | ||||
|          * 0: bias | ||||
| @ -42,8 +42,8 @@ namespace nd4j { | ||||
|          * 2: beta | ||||
|          * | ||||
|          * Int arg: depth - optional local radius | ||||
|          *  | ||||
|          * output - 4D array  | ||||
|          * | ||||
|          * output - 4D array | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_lrn) | ||||
|         DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0); | ||||
| @ -51,10 +51,10 @@ namespace nd4j { | ||||
| 
 | ||||
|         /**
 | ||||
|          * Local response normalization - backprop variant. | ||||
|          * input:  | ||||
|          * input: | ||||
|          *  0 - 4D array of data | ||||
|          *  1 - epsilon - 4D array of approximation | ||||
|          *  | ||||
|          * | ||||
|          * T args: | ||||
|          * | ||||
|          * 0: bias | ||||
| @ -70,34 +70,31 @@ namespace nd4j { | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|         * Batch normalization implementation.  | ||||
|         * Batch normalization implementation. | ||||
|         * Reference: https://arxiv.org/abs/1502.03167v3
 | ||||
|         *  | ||||
|         * | ||||
|         * Expected arguments: | ||||
|         * input: input array (any number of dimensions) | ||||
|         * mean: | ||||
|         * variance: | ||||
|         * gamma: | ||||
|         * beta: | ||||
|         *  | ||||
|         * | ||||
|         * Int args: | ||||
|         * 0: apply scale | ||||
|         * 1: apply offset | ||||
|         *  | ||||
|         *  | ||||
|         * | ||||
|         * | ||||
|         * T args: | ||||
|         * 0: epsilon | ||||
|         */ | ||||
|         #if NOT_EXCLUDED(OP_batchnorm) | ||||
|         DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2); | ||||
|         #endif | ||||
|         #if NOT_EXCLUDED(OP_batchnorm_new) | ||||
|         DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2); | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|         * back prop in batch normalization | ||||
|         *  | ||||
|         * | ||||
|         * Expected arguments: | ||||
|         * input: input array (any number of dimensions) | ||||
|         * mean: | ||||
| @ -105,11 +102,11 @@ namespace nd4j { | ||||
|         * gamma: optional | ||||
|         * beta: optional | ||||
|         * dLdOut: next epsilon | ||||
|         *  | ||||
|         * | ||||
|         * Int args: | ||||
|         * 0: apply scale | ||||
|         * 1: apply offset  | ||||
|         *  | ||||
|         * 1: apply offset | ||||
|         * | ||||
|         * T args: | ||||
|         * 0: epsilon | ||||
|         * | ||||
| @ -117,8 +114,8 @@ namespace nd4j { | ||||
|         * dL/dInput | ||||
|         * dL/dMean | ||||
|         * dL/dVariance | ||||
|         * dL/dGamma | ||||
|         * dL/dBeta | ||||
|         * dL/dGamma, optional | ||||
|         * dL/dBeta, optional | ||||
|         */ | ||||
|         #if NOT_EXCLUDED(OP_batchnorm) | ||||
|         DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2); | ||||
| @ -131,30 +128,30 @@ namespace nd4j { | ||||
|          * x: parameters, any shape | ||||
|          * y: gradients. same shape as x | ||||
|          * lr: optional, learning rate | ||||
|          *  | ||||
|          * | ||||
|          * T args: | ||||
|          * 0: optional, learning rate | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_apply_sgd) | ||||
|         DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0);    | ||||
|         DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0); | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|          * This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
 | ||||
|          * Expected arguments: | ||||
|          * x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where | ||||
|          *    bS - batch size  | ||||
|          *    iH - input height     | ||||
|          *    iW - input width  | ||||
|          *    bS - batch size | ||||
|          *    iH - input height | ||||
|          *    iW - input width | ||||
|          *    iD - input depth (or number of channels) | ||||
|          * scale:  1D input array of scale factors, shape [iD] | ||||
|          * offset: 1D input array of offsets (shifts), shape [iD] | ||||
|          * mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false | ||||
|          * variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false | ||||
|          *  | ||||
|          * | ||||
|          * T input arguments: | ||||
|          * 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x | ||||
|          *  | ||||
|          * | ||||
|          * integer input arguments: | ||||
|          * 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW | ||||
|          * 1: isTraining, may have two values: zero -> inference, unity -> training | ||||
|  | ||||
| @ -32,6 +32,8 @@ namespace helpers { | ||||
| template <typename T> | ||||
| static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector<int>& axes, const double epsilon) { | ||||
| 
 | ||||
|     // formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
 | ||||
| 
 | ||||
|     NDArray sigmaInvGam(mean);  // do not copy mean's buffer, take only its shapeInfo
 | ||||
|     T eps = epsilon; | ||||
| 
 | ||||
|  | ||||
| @ -17,6 +17,7 @@ | ||||
| //
 | ||||
| // @author saudet
 | ||||
| // @author raver119@gmail.com
 | ||||
| // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||||
| //
 | ||||
| 
 | ||||
| #include <ops/declarable/PlatformHelper.h> | ||||
| @ -28,139 +29,679 @@ | ||||
| #include <ops/declarable/helpers/convolutions.h> | ||||
| #include <NDArrayFactory.h> | ||||
| 
 | ||||
| using namespace mkldnn; | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         namespace platforms { | ||||
|             PLATFORM_IMPL(batchnorm_new) { | ||||
|                 auto input = INPUT_VARIABLE(0); | ||||
|                 auto mean = INPUT_VARIABLE(1); | ||||
|                 auto variance = INPUT_VARIABLE(2); | ||||
|                 NDArray *gamma = nullptr; | ||||
|                 NDArray *beta = nullptr; | ||||
| namespace nd4j      { | ||||
| namespace ops       { | ||||
| namespace platforms { | ||||
| 
 | ||||
|                 auto output = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|                 const bool applyScale = (bool) INT_ARG(0); | ||||
|                 const bool applyOffset = (bool) INT_ARG(1); | ||||
|                 const double epsilon = T_ARG(0); | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) { | ||||
| 
 | ||||
|                 if (applyScale) | ||||
|                     gamma = INPUT_VARIABLE(3); | ||||
|                 if (applyOffset) | ||||
|                     beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale)); | ||||
|     // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
 | ||||
|     // also it gives wrong results for formats nhwc and ndhwc
 | ||||
| 
 | ||||
|                 std::vector<int> axes; | ||||
|                 if (block.numI() > 2) | ||||
|                     for (int i = 2; i < block.numI(); ++i) | ||||
|                         axes.push_back(INT_ARG(i)); | ||||
|                 else | ||||
|                     axes.push_back(input->rankOf() - 1); | ||||
|     // x -> 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     // mean -> 1D [c]
 | ||||
|     // variance -> 1D [c]
 | ||||
|     // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
 | ||||
|     // z(output) - same shape as x
 | ||||
| 
 | ||||
|                 std::vector<Nd4jLong> shape({2, mean->lengthOf()}); | ||||
|                 NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext()); | ||||
|                 weights({0, 1, 0, 0}).assign(1.0f); | ||||
|                 weights({1, 2, 0, 0}).assign(0.0f); | ||||
|     const int xRank = x->rankOf(); | ||||
| 
 | ||||
|                 mkldnn_memory_desc_t empty; | ||||
|                 mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md( | ||||
|                         empty), user_dst_md(empty); | ||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||
| 
 | ||||
|                 auto norm_flag = normalization_flags::use_global_stats; | ||||
|                 if (applyScale || applyOffset) | ||||
|                     norm_flag |= normalization_flags::use_scale_shift; | ||||
|     // input type
 | ||||
|     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; | ||||
| 
 | ||||
|                 mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output, | ||||
|                                                           &batchnorm_src_md, nullptr, &batchnorm_dst_md, | ||||
|                                                           &user_src_md, nullptr, &user_dst_md, axes[0]); | ||||
|     // indicate whether gamma or/and beta are given
 | ||||
|     auto flags = mkldnn::normalization_flags::use_global_stats; | ||||
|     if (weights != nullptr) | ||||
|         flags |= mkldnn::normalization_flags::use_scale_shift; | ||||
| 
 | ||||
|                 auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag); | ||||
|     mkldnn::memory::dims dims; | ||||
|     mkldnn::memory::format_tag format; | ||||
| 
 | ||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||
|                 mkldnn::stream stream(engine); | ||||
|                 auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine); | ||||
|                 auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer()); | ||||
|                 auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer()); | ||||
|                 auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine, | ||||
|                                                             mean->buffer()); | ||||
|                 auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine, | ||||
|                                                                 variance->buffer()); | ||||
|                 auto batchnorm_src_memory = user_src_memory; | ||||
|                 mkldnn::memory m(batchnorm_src_md, engine); | ||||
|                 if (m.get_desc() != user_src_memory.get_desc()) { | ||||
|                     batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine); | ||||
|                     reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory, | ||||
|                                                                            batchnorm_src_memory); | ||||
|                 } | ||||
|                 auto batchnorm_dst_memory = user_dst_memory; | ||||
|                 if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { | ||||
|                     batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine); | ||||
|                 } | ||||
|                 if (applyScale || applyOffset) { | ||||
|                     if (gamma != nullptr) { | ||||
|                         weights({0, 1, 0, 0}).assign(gamma); | ||||
|                     } | ||||
|                     if (beta != nullptr) { | ||||
|                         weights({1, 2, 0, 0}).assign(beta); | ||||
|                     } | ||||
| 
 | ||||
|                     auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer()); | ||||
|                     batch_normalization_forward(batchnorm_prim_desc).execute(stream, | ||||
|                                                                              {{MKLDNN_ARG_SRC,      batchnorm_src_memory}, | ||||
|                                                                               {MKLDNN_ARG_MEAN,     batchnorm_mean_memory}, | ||||
|                                                                               {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, | ||||
|                                                                               {MKLDNN_ARG_WEIGHTS,  batchnorm_weights_memory}, | ||||
|                                                                               {MKLDNN_ARG_DST,      batchnorm_dst_memory}}); | ||||
|                 } else { | ||||
|                     batch_normalization_forward(batchnorm_prim_desc).execute(stream, | ||||
|                                                                              {{MKLDNN_ARG_SRC,      batchnorm_src_memory}, | ||||
|                                                                               {MKLDNN_ARG_MEAN,     batchnorm_mean_memory}, | ||||
|                                                                               {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory}, | ||||
|                                                                               {MKLDNN_ARG_DST,      batchnorm_dst_memory}}); | ||||
|                 } | ||||
|                 if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) { | ||||
|                     reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory, | ||||
|                                                                            user_dst_memory); | ||||
|                 } | ||||
|                 stream.wait(); | ||||
| 
 | ||||
|                 return Status::OK(); | ||||
|             } | ||||
| 
 | ||||
|             PLATFORM_CHECK(batchnorm_new) { | ||||
|                 // we don't want to use mkldnn if cpu doesn't support avx/avx2
 | ||||
|                 if (::optimalLevel() < 2) | ||||
|                     return false; | ||||
| 
 | ||||
|                 auto input = INPUT_VARIABLE(0); | ||||
|                 auto mean = INPUT_VARIABLE(1); | ||||
|                 auto variance = INPUT_VARIABLE(2); | ||||
|                 NDArray *gamma = nullptr; | ||||
|                 NDArray *beta = nullptr; | ||||
| 
 | ||||
|                 auto output = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|                 const bool applyScale = (bool) INT_ARG(0); | ||||
|                 const bool applyOffset = (bool) INT_ARG(1); | ||||
|                 const double epsilon = T_ARG(0); | ||||
| 
 | ||||
|                 if (applyScale) | ||||
|                     gamma = INPUT_VARIABLE(3); | ||||
|                 if (applyOffset) | ||||
|                     beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale)); | ||||
| 
 | ||||
|                 std::vector<int> axes; | ||||
|                 if (block.numI() > 2) | ||||
|                     for (int i = 2; i < block.numI(); ++i) | ||||
|                         axes.push_back(INT_ARG(i)); | ||||
|                 else | ||||
|                     axes.push_back(input->rankOf() - 1); | ||||
| 
 | ||||
|                 return block.isUseMKLDNN() && | ||||
|                        nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) && | ||||
|                        axes.size() == 1; | ||||
|             } | ||||
|         } | ||||
|     if(xRank == 2) { | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1)}; | ||||
|         format = mkldnn::memory::format_tag::nc; | ||||
|     } | ||||
|     else if(xRank == 4) { | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; | ||||
|         format = mkldnn::memory::format_tag::nchw; | ||||
|     } | ||||
|     else {  // xRank = 5
 | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; | ||||
|         format = mkldnn::memory::format_tag::ncdhw; | ||||
|     } | ||||
| 
 | ||||
|     // memory descriptors for arrays
 | ||||
| 
 | ||||
|     // x
 | ||||
|     mkldnn::memory::desc x_mkl_md  = mkldnn::memory::desc(dims, type, format); | ||||
|     mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); | ||||
|     x_user_md.data.format_kind = mkldnn_blocked;    // overrides format
 | ||||
|     x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; | ||||
|     x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; | ||||
|     if(xRank > 2) { | ||||
|         x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; | ||||
|         x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; | ||||
|     } | ||||
|     if(xRank > 4) | ||||
|         x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; | ||||
| 
 | ||||
|     // z, output
 | ||||
|     mkldnn::memory::desc z_mkl_md  = mkldnn::memory::desc(dims, type, format); | ||||
|     mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format); | ||||
|     z_user_md.data.format_kind = mkldnn_blocked;    // overrides format
 | ||||
|     z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; | ||||
|     z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; | ||||
|     if(xRank > 2) { | ||||
|         z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2]; | ||||
|         z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3]; | ||||
|     } | ||||
|     if(xRank > 4) | ||||
|         z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4]; | ||||
| 
 | ||||
| 
 | ||||
|     // batchnorm forward description
 | ||||
|     mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); | ||||
|     mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); | ||||
| 
 | ||||
|     // arguments (memory buffers) necessary for calculations
 | ||||
|     std::unordered_map<int, mkldnn::memory> args; | ||||
| 
 | ||||
|     mkldnn::stream stream(engine); | ||||
| 
 | ||||
|     // provide memory and check whether reorder is required
 | ||||
| 
 | ||||
|     // x
 | ||||
|     auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); | ||||
|     const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); | ||||
|     auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; | ||||
|     if (xReorder) | ||||
|         mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); | ||||
|     args[MKLDNN_ARG_SRC] = x_mkl_mem; | ||||
| 
 | ||||
|     // z
 | ||||
|     auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer()); | ||||
|     const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); | ||||
|     auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; | ||||
|     if (zReorder) | ||||
|         mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); | ||||
|     args[MKLDNN_ARG_DST] = z_mkl_mem; | ||||
| 
 | ||||
|     // mean
 | ||||
|     auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer()); | ||||
|     args[MKLDNN_ARG_MEAN] = mean_mkl_mem; | ||||
| 
 | ||||
|     // variance
 | ||||
|     auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer()); | ||||
|     args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; | ||||
| 
 | ||||
|     // gamma and beta (and their gradients) if they are present
 | ||||
|     if(weights != nullptr) { | ||||
| 
 | ||||
|         auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer()); | ||||
|         args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; | ||||
|     } | ||||
| 
 | ||||
|     // run calculations
 | ||||
|     mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); | ||||
| 
 | ||||
|     // reorder outputs if necessary
 | ||||
|     if (zReorder) | ||||
|         mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); | ||||
| 
 | ||||
|     stream.wait(); | ||||
| 
 | ||||
|     // shape::printArray(z_mkl_mem.map_data<float>(),8);
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, | ||||
|                                     const float epsilon, NDArray* dLdI, NDArray* dLdW) { | ||||
| 
 | ||||
|     // unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
 | ||||
|     // also it gives wrong results for formats nhwc and ndhwc
 | ||||
| 
 | ||||
|     // x -> 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     // mean -> 1D [c]
 | ||||
|     // variance -> 1D [c]
 | ||||
|     // dLdO - same shape as x
 | ||||
|     // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
 | ||||
|     // dLdI - same shape as x
 | ||||
|     // dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta
 | ||||
| 
 | ||||
|     const int xRank = x->rankOf(); | ||||
| 
 | ||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||
| 
 | ||||
|     // input type
 | ||||
|     mkldnn::memory::data_type type = mkldnn::memory::data_type::f32; | ||||
| 
 | ||||
|     // indicate whether gamma or/and beta are given
 | ||||
|     auto flags = mkldnn::normalization_flags::use_global_stats; | ||||
|     if (weights != nullptr) | ||||
|         flags |= mkldnn::normalization_flags::use_scale_shift; | ||||
| 
 | ||||
|     mkldnn::memory::dims dims; | ||||
|     mkldnn::memory::format_tag format; | ||||
| 
 | ||||
|     if(xRank == 2) { | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1)}; | ||||
|         format = mkldnn::memory::format_tag::nc; | ||||
|     } | ||||
|     else if(xRank == 4) { | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; | ||||
|         format = mkldnn::memory::format_tag::nchw; | ||||
|     } | ||||
|     else {  // xRank = 5
 | ||||
|         dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; | ||||
|         format = mkldnn::memory::format_tag::ncdhw; | ||||
|     } | ||||
| 
 | ||||
|     // memory descriptors for arrays
 | ||||
| 
 | ||||
|     // x
 | ||||
|     mkldnn::memory::desc x_mkl_md  = mkldnn::memory::desc(dims, type, format); | ||||
|     mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format); | ||||
|     x_user_md.data.format_kind = mkldnn_blocked;    // overrides format
 | ||||
|     x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; | ||||
|     x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; | ||||
|     if(xRank > 2) { | ||||
|         x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; | ||||
|         x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; | ||||
|     } | ||||
|     if(xRank > 4) | ||||
|         x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; | ||||
| 
 | ||||
|     // dLdO
 | ||||
|     mkldnn::memory::desc dLdO_mkl_md  = mkldnn::memory::desc(dims, type, format); | ||||
|     mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format); | ||||
|     dLdO_user_md.data.format_kind = mkldnn_blocked;    // overrides format
 | ||||
|     dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; | ||||
|     dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; | ||||
|     if(xRank > 2) { | ||||
|         dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2]; | ||||
|         dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3]; | ||||
|     } | ||||
|     if(xRank > 4) | ||||
|         dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; | ||||
| 
 | ||||
|     // dLdI
 | ||||
|     mkldnn::memory::desc dLdI_mkl_md  = mkldnn::memory::desc(dims, type, format); | ||||
|     mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format); | ||||
|     dLdI_user_md.data.format_kind = mkldnn_blocked;    // overrides format
 | ||||
|     dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; | ||||
|     dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; | ||||
|     if(xRank > 2) { | ||||
|         dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2]; | ||||
|         dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3]; | ||||
|     } | ||||
|     if(xRank > 4) | ||||
|         dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; | ||||
| 
 | ||||
|     // batchnorm forward description
 | ||||
|     mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags); | ||||
|     mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); | ||||
| 
 | ||||
|     // batchnorm backprop description
 | ||||
|     mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags); | ||||
|     mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); | ||||
| 
 | ||||
|     // arguments (memory buffers) necessary for calculations
 | ||||
|     std::unordered_map<int, mkldnn::memory> args; | ||||
| 
 | ||||
|     mkldnn::stream stream(engine); | ||||
| 
 | ||||
|     // provide memory and check whether reorder is required
 | ||||
| 
 | ||||
|     // x
 | ||||
|     auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer()); | ||||
|     const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc(); | ||||
|     auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem; | ||||
|     if (xReorder) | ||||
|         mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); | ||||
|     args[MKLDNN_ARG_SRC] = x_mkl_mem; | ||||
| 
 | ||||
|     // dLdO
 | ||||
|     auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer()); | ||||
|     const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc(); | ||||
|     auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem; | ||||
|     if (dLdOReorder) | ||||
|         mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem); | ||||
|     args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem; | ||||
| 
 | ||||
|     // mean
 | ||||
|     auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); | ||||
|     args[MKLDNN_ARG_MEAN] = mean_mkl_mem; | ||||
| 
 | ||||
|     // variance
 | ||||
|     auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer()); | ||||
|     args[MKLDNN_ARG_VARIANCE] = var_mkl_mem; | ||||
| 
 | ||||
|     // dLdI
 | ||||
|     auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer()); | ||||
|     const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc(); | ||||
|     auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem; | ||||
|     args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem; | ||||
| 
 | ||||
|     // gamma and beta (and their gradients) if they are present
 | ||||
|     if(weights != nullptr) { | ||||
| 
 | ||||
|         auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer()); | ||||
|         args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem; | ||||
| 
 | ||||
|         auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer()); | ||||
|         args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem; | ||||
|     } | ||||
| 
 | ||||
|     // run calculations
 | ||||
|     mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); | ||||
| 
 | ||||
|     // reorder outputs if necessary
 | ||||
|     if (dLdIReorder) | ||||
|         mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); | ||||
| 
 | ||||
|     stream.wait(); | ||||
| 
 | ||||
|     // shape::printArray(dLdI_mkl_mem.map_data<float>(),8);
 | ||||
| } | ||||
| 
 | ||||
| PLATFORM_IMPL(batchnorm) { | ||||
| 
 | ||||
|     auto input    = INPUT_VARIABLE(0);  // 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     auto mean     = INPUT_VARIABLE(1);  // [c]
 | ||||
|     auto variance = INPUT_VARIABLE(2);  // [c]
 | ||||
|     NDArray* gamma    = nullptr;        // [c]
 | ||||
|     NDArray* beta     = nullptr;        // [c]
 | ||||
| 
 | ||||
|     auto output = OUTPUT_VARIABLE(0);   // same shape as input
 | ||||
| 
 | ||||
|     const bool   applyScale  = (bool)INT_ARG(0); | ||||
|     const bool   applyOffset = (bool)INT_ARG(1); | ||||
|     const double epsilon     = T_ARG(0); | ||||
| 
 | ||||
|     if(applyScale) | ||||
|         gamma = INPUT_VARIABLE(3); | ||||
|     if(applyOffset) | ||||
|         beta = INPUT_VARIABLE(3 + (int)applyScale); | ||||
| 
 | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     const int inRank = input->rankOf(); | ||||
| 
 | ||||
|     // get axes args to normalize input array over
 | ||||
|     std::vector<int> axes; | ||||
|     if(numOfIntArgs > 2) | ||||
|         for(int i = 2; i < numOfIntArgs; ++i) | ||||
|             axes.push_back(INT_ARG(i)); | ||||
|     else | ||||
|         axes.push_back(inRank-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     const int numOfAxes = axes.size(); | ||||
|     REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); | ||||
|     REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); | ||||
|     REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); | ||||
|     REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); | ||||
|     if(gamma != nullptr) | ||||
|         REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); | ||||
|     if(beta != nullptr) | ||||
|         REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); | ||||
| 
 | ||||
|     // types of all input arrays should be the same (except dLdO)
 | ||||
|     for(int i = 1; i < block.width() - 1; ++i) | ||||
|         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !"); | ||||
| 
 | ||||
| 
 | ||||
|     NDArray *weights = nullptr; | ||||
| 
 | ||||
|     if(applyScale || applyOffset) { | ||||
| 
 | ||||
|         weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); | ||||
| 
 | ||||
|         if(applyScale) | ||||
|             (*weights)({0,1, 0,0}).assign(gamma); | ||||
|         else | ||||
|             (*weights)({0,1, 0,0}).assign(1); | ||||
|         if(applyOffset) | ||||
|             (*weights)({1,2, 0,0}).assign(beta); | ||||
|         else | ||||
|             (*weights)({1,2, 0,0}).assign(0); | ||||
|     } | ||||
| 
 | ||||
|     batchnormMKLDNN(input, mean, variance, weights, epsilon, output); | ||||
| 
 | ||||
|     delete weights; | ||||
| 
 | ||||
|     return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| PLATFORM_CHECK(batchnorm) { | ||||
|     // we don't want to use mkldnn if cpu doesn't support avx/avx2
 | ||||
|     // if (::optimalLevel() < 2)
 | ||||
|     //     return false;
 | ||||
| 
 | ||||
|     auto input    = INPUT_VARIABLE(0);  // 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     auto mean     = INPUT_VARIABLE(1);  // [c]
 | ||||
|     auto variance = INPUT_VARIABLE(2);  // [c]
 | ||||
|     NDArray* gamma    = nullptr;        // [c]
 | ||||
|     NDArray* beta     = nullptr;        // [c]
 | ||||
| 
 | ||||
|     auto output = OUTPUT_VARIABLE(0);   // same shape as input
 | ||||
| 
 | ||||
|     const bool   applyScale  = (bool)INT_ARG(0); | ||||
|     const bool   applyOffset = (bool)INT_ARG(1); | ||||
| 
 | ||||
|     if(applyScale) | ||||
|         gamma = INPUT_VARIABLE(3); | ||||
|     if(applyOffset) | ||||
|         beta = INPUT_VARIABLE(3 + (int)applyScale); | ||||
| 
 | ||||
| 
 | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     std::vector<int> axes; | ||||
|     if(numOfIntArgs > 2) | ||||
|         for(int i = 2; i < numOfIntArgs; ++i) | ||||
|             axes.push_back(INT_ARG(i)); | ||||
|     else | ||||
|         axes.push_back(input->rankOf()-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     DataType inputType = input->dataType(); | ||||
|     DataType meanType  = mean->dataType(); | ||||
|     DataType varType   = variance->dataType(); | ||||
|     DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; | ||||
|     DataType betaType  = beta  != nullptr ? beta->dataType()  : DataType::FLOAT32; | ||||
|     DataType outType   = output->dataType(); | ||||
| 
 | ||||
|     const int inRank = input->rankOf(); | ||||
| 
 | ||||
|     return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && | ||||
|             (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 && | ||||
|              gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // PLATFORM_IMPL(batchnorm) {
 | ||||
| 
 | ||||
| //     auto input = INPUT_VARIABLE(0);
 | ||||
| //     auto mean = INPUT_VARIABLE(1);
 | ||||
| //     auto variance = INPUT_VARIABLE(2);
 | ||||
| //     NDArray *gamma = nullptr;
 | ||||
| //     NDArray *beta = nullptr;
 | ||||
| 
 | ||||
| //     auto output = OUTPUT_VARIABLE(0);
 | ||||
| 
 | ||||
| //     const bool applyScale = (bool) INT_ARG(0);
 | ||||
| //     const bool applyOffset = (bool) INT_ARG(1);
 | ||||
| //     const double epsilon = T_ARG(0);
 | ||||
| 
 | ||||
| //     if (applyScale)
 | ||||
| //         gamma = INPUT_VARIABLE(3);
 | ||||
| //     if (applyOffset)
 | ||||
| //         beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
 | ||||
| 
 | ||||
| //     std::vector<int> axes;
 | ||||
| //     if (block.numI() > 2)
 | ||||
| //         for (int i = 2; i < block.numI(); ++i)
 | ||||
| //             axes.push_back(INT_ARG(i));
 | ||||
| //     else
 | ||||
| //         axes.push_back(input->rankOf() - 1);
 | ||||
| 
 | ||||
| //     std::vector<Nd4jLong> shape({2, mean->lengthOf()});
 | ||||
| //     NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
 | ||||
| //     weights({0, 1, 0, 0}).assign(1.0f);
 | ||||
| //     weights({1, 2, 0, 0}).assign(0.0f);
 | ||||
| 
 | ||||
| //     mkldnn_memory_desc_t empty;
 | ||||
| //     mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
 | ||||
| 
 | ||||
| //     auto flag = mkldnn::normalization_flags::use_global_stats;
 | ||||
| //     if (applyScale || applyOffset)
 | ||||
| //         flag |= mkldnn::normalization_flags::use_scale_shift;
 | ||||
| 
 | ||||
| //     mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
 | ||||
| //                                               &batchnorm_src_md, nullptr, &batchnorm_dst_md,
 | ||||
| //                                               &user_src_md, nullptr, &user_dst_md, axes[0]);
 | ||||
| 
 | ||||
| //     auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
 | ||||
| 
 | ||||
| //     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
 | ||||
| //     mkldnn::stream stream(engine);
 | ||||
| //     auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
 | ||||
| //     auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
 | ||||
| //     auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
 | ||||
| //     auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
 | ||||
| //                                                 mean->buffer());
 | ||||
| //     auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
 | ||||
| //                                                     variance->buffer());
 | ||||
| //     auto batchnorm_src_memory = user_src_memory;
 | ||||
| //     mkldnn::memory m(batchnorm_src_md, engine);
 | ||||
| //     if (m.get_desc() != user_src_memory.get_desc()) {
 | ||||
| //         batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
 | ||||
| //         mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
 | ||||
| //                                                                batchnorm_src_memory);
 | ||||
| //     }
 | ||||
| //     auto batchnorm_dst_memory = user_dst_memory;
 | ||||
| //     if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 | ||||
| //         batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
 | ||||
| //     }
 | ||||
| //     if (applyScale || applyOffset) {
 | ||||
| //         if (gamma != nullptr) {
 | ||||
| //             weights({0, 1, 0, 0}).assign(gamma);
 | ||||
| //         }
 | ||||
| //         if (beta != nullptr) {
 | ||||
| //             weights({1, 2, 0, 0}).assign(beta);
 | ||||
| //         }
 | ||||
| 
 | ||||
| //         auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
 | ||||
| //         mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
 | ||||
| //                                                                  {{MKLDNN_ARG_SRC,      batchnorm_src_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_MEAN,     batchnorm_mean_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_WEIGHTS,  batchnorm_weights_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_DST,      batchnorm_dst_memory}});
 | ||||
| //     } else {
 | ||||
| //         mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
 | ||||
| //                                                                  {{MKLDNN_ARG_SRC,      batchnorm_src_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_MEAN,     batchnorm_mean_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
 | ||||
| //                                                                   {MKLDNN_ARG_DST,      batchnorm_dst_memory}});
 | ||||
| //     }
 | ||||
| //     if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
 | ||||
| //         mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
 | ||||
| //                                                                user_dst_memory);
 | ||||
| //     }
 | ||||
| //     stream.wait();
 | ||||
| 
 | ||||
| //     return Status::OK();
 | ||||
| // }
 | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // PLATFORM_CHECK(batchnorm) {
 | ||||
| //     // we don't want to use mkldnn if cpu doesn't support avx/avx2
 | ||||
| //     if (::optimalLevel() < 2)
 | ||||
| //         return false;
 | ||||
| 
 | ||||
| //     auto input = INPUT_VARIABLE(0);
 | ||||
| //     auto mean = INPUT_VARIABLE(1);
 | ||||
| //     auto variance = INPUT_VARIABLE(2);
 | ||||
| //     NDArray *gamma = nullptr;
 | ||||
| //     NDArray *beta = nullptr;
 | ||||
| 
 | ||||
| //     auto output = OUTPUT_VARIABLE(0);
 | ||||
| 
 | ||||
| //     const bool applyScale = (bool) INT_ARG(0);
 | ||||
| //     const bool applyOffset = (bool) INT_ARG(1);
 | ||||
| //     const double epsilon = T_ARG(0);
 | ||||
| 
 | ||||
| //     if (applyScale)
 | ||||
| //         gamma = INPUT_VARIABLE(3);
 | ||||
| //     if (applyOffset)
 | ||||
| //         beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
 | ||||
| 
 | ||||
| //     std::vector<int> axes;
 | ||||
| //     if (block.numI() > 2)
 | ||||
| //         for (int i = 2; i < block.numI(); ++i)
 | ||||
| //             axes.push_back(INT_ARG(i));
 | ||||
| //     else
 | ||||
| //         axes.push_back(input->rankOf() - 1);
 | ||||
| 
 | ||||
| //     return block.isUseMKLDNN() &&
 | ||||
| //            nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
 | ||||
| //            axes.size() == 1;
 | ||||
| // }
 | ||||
| 
 | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| PLATFORM_IMPL(batchnorm_bp) { | ||||
| 
 | ||||
|     NDArray* input    = INPUT_VARIABLE(0);      // 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     NDArray* mean     = INPUT_VARIABLE(1);      // [c]
 | ||||
|     NDArray* variance = INPUT_VARIABLE(2);      // [c]
 | ||||
|     NDArray* dLdO     = INPUT_VARIABLE(3);      // same as input
 | ||||
|     NDArray* gamma    = nullptr;                // [c]
 | ||||
|     NDArray* beta     = nullptr;                // [c]
 | ||||
| 
 | ||||
|     NDArray* dLdI = OUTPUT_VARIABLE(0);         // same as input
 | ||||
|     NDArray* dLdM = OUTPUT_VARIABLE(1);         // [c]
 | ||||
|     NDArray* dLdV = OUTPUT_VARIABLE(2);         // [c]
 | ||||
|     NDArray* dLdG = nullptr;                    // [c]
 | ||||
|     NDArray* dLdB = nullptr;                    // [c]
 | ||||
| 
 | ||||
|     const bool  applyScale  = (bool)INT_ARG(0); | ||||
|     const bool  applyOffset = (bool)INT_ARG(1); | ||||
|     const float epsilon     = T_ARG(0); | ||||
| 
 | ||||
|     if(applyScale) { | ||||
|         gamma = INPUT_VARIABLE(4); | ||||
|         dLdG  = OUTPUT_VARIABLE(3); | ||||
|     } | ||||
|     if(applyOffset) { | ||||
|         beta = INPUT_VARIABLE(4 + (int)applyScale); | ||||
|         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); | ||||
|     } | ||||
| 
 | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     const int inRank = input->rankOf(); | ||||
| 
 | ||||
|     // get axes args to normalize input array over
 | ||||
|     std::vector<int> axes; | ||||
|     if(numOfIntArgs > 2) | ||||
|         for(int i = 2; i < numOfIntArgs; ++i) | ||||
|             axes.push_back(INT_ARG(i)); | ||||
|     else | ||||
|         axes.push_back(inRank-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     const int numOfAxes = axes.size(); | ||||
|     REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes); | ||||
|     REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank); | ||||
|     REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str()); | ||||
|     REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str()); | ||||
|     REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str()); | ||||
|     if(gamma != nullptr) | ||||
|         REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str()); | ||||
|     if(beta != nullptr) | ||||
|         REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str()); | ||||
| 
 | ||||
|     // types of all input arrays should be the same (except dLdO)
 | ||||
|     for(int i = 1; i < block.width() - 1; ++i) | ||||
|         REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !"); | ||||
| 
 | ||||
| 
 | ||||
|     NDArray *weights = nullptr, *dLdW = nullptr; | ||||
| 
 | ||||
|     if(applyScale || applyOffset) { | ||||
|         weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); | ||||
|         dLdW    = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType()); | ||||
|         if(applyScale) | ||||
|             (*weights)({0,1, 0,0}).assign(gamma); | ||||
|         else | ||||
|             (*weights)({0,1, 0,0}).assign(1); | ||||
|         if(applyOffset) | ||||
|             (*weights)({1,2, 0,0}).assign(beta); | ||||
|         else | ||||
|             (*weights)({1,2, 0,0}).assign(0); | ||||
|     } | ||||
| 
 | ||||
|     *dLdM = 0; | ||||
|     *dLdV = 0; | ||||
| 
 | ||||
|     batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW); | ||||
| 
 | ||||
|     if(applyScale || applyOffset) { | ||||
|         if(applyScale) | ||||
|             dLdG->assign((*dLdW)({0,1, 0,0})); | ||||
|         if(applyOffset) | ||||
|             dLdB->assign((*dLdW)({1,2, 0,0})); | ||||
| 
 | ||||
|         delete weights; | ||||
|         delete dLdW; | ||||
|     } | ||||
| 
 | ||||
|     return Status::OK(); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| PLATFORM_CHECK(batchnorm_bp) { | ||||
|     // we don't want to use mkldnn if cpu doesn't support avx/avx2
 | ||||
|     // if (::optimalLevel() < 2)
 | ||||
|     //     return false;
 | ||||
| 
 | ||||
|     NDArray* input    = INPUT_VARIABLE(0);      // 2D:nc, 4D:nchw, 5D:ncdhw
 | ||||
|     NDArray* mean     = INPUT_VARIABLE(1);      // [c]
 | ||||
|     NDArray* variance = INPUT_VARIABLE(2);      // [c]
 | ||||
|     NDArray* dLdO     = INPUT_VARIABLE(3);      // same as input
 | ||||
|     NDArray* gamma    = nullptr;                // [c]
 | ||||
|     NDArray* beta     = nullptr;                // [c]
 | ||||
| 
 | ||||
|     NDArray* dLdI = OUTPUT_VARIABLE(0);         // same as input
 | ||||
|     NDArray* dLdM = OUTPUT_VARIABLE(1);         // [c]
 | ||||
|     NDArray* dLdV = OUTPUT_VARIABLE(2);         // [c]
 | ||||
|     NDArray* dLdG = nullptr;                    // [c]
 | ||||
|     NDArray* dLdB = nullptr;                    // [c]
 | ||||
| 
 | ||||
|     const bool  applyScale  = (bool)INT_ARG(0); | ||||
|     const bool  applyOffset = (bool)INT_ARG(1); | ||||
| 
 | ||||
|     if(applyScale) { | ||||
|         gamma = INPUT_VARIABLE(4); | ||||
|         dLdG  = OUTPUT_VARIABLE(3); | ||||
|     } | ||||
|     if(applyOffset) { | ||||
|         beta = INPUT_VARIABLE(4 + (int)applyScale); | ||||
|         dLdB = OUTPUT_VARIABLE(3 + (int)applyScale); | ||||
|     } | ||||
| 
 | ||||
|     const int numOfIntArgs = block.getIArguments()->size(); | ||||
|     std::vector<int> axes; | ||||
|     if(numOfIntArgs > 2) | ||||
|         for(int i = 2; i < numOfIntArgs; ++i) | ||||
|             axes.push_back(INT_ARG(i)); | ||||
|     else | ||||
|         axes.push_back(input->rankOf()-1);               // default dimension to reduce along is last dimension
 | ||||
| 
 | ||||
|     DataType inputType = input->dataType(); | ||||
|     DataType meanType  = mean->dataType(); | ||||
|     DataType varType   = variance->dataType(); | ||||
|     DataType dLdOType  = dLdO->dataType(); | ||||
|     DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32; | ||||
|     DataType betaType  = beta  != nullptr ? beta->dataType()  : DataType::FLOAT32; | ||||
| 
 | ||||
|     DataType dLdIType = dLdI->dataType(); | ||||
|     DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32; | ||||
|     DataType dLdBType = beta  != nullptr ? dLdB->dataType() : DataType::FLOAT32; | ||||
| 
 | ||||
|     const int inRank = input->rankOf(); | ||||
| 
 | ||||
|     return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) && | ||||
|             (inputType == DataType::FLOAT32 && meanType  == DataType::FLOAT32 && varType  == DataType::FLOAT32 && | ||||
|              dLdOType  == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && | ||||
|              dLdIType  == DataType::FLOAT32 && dLdGType  == DataType::FLOAT32 && dLdBType == DataType::FLOAT32); | ||||
| } | ||||
| 
 | ||||
| } | ||||
| } | ||||
| } | ||||
| @ -132,8 +132,6 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* | ||||
| 
 | ||||
|     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||
| 
 | ||||
|     mkldnn_memory_desc_t empty; | ||||
| 
 | ||||
|     mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md, | ||||
|                          x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md; | ||||
| 
 | ||||
|  | ||||
| @ -305,50 +305,50 @@ namespace nd4j { | ||||
|         }; | ||||
| 
 | ||||
| 
 | ||||
|         void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, | ||||
|                                           mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md, | ||||
|                                           mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) { | ||||
|             const Nd4jLong* shape = src->getShapeInfo(); | ||||
|             Nd4jLong rank = shape[0]; | ||||
|             Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
 | ||||
|             Nd4jLong dim2 = axis >= 2 ? 1 : 2; | ||||
|             Nd4jLong dim3 = axis >= 3 ? 2 : 3; | ||||
|             mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; | ||||
|         // void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
 | ||||
|         //                                   mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
 | ||||
|         //                                   mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
 | ||||
|         //     const Nd4jLong* shape = src->getShapeInfo();
 | ||||
|         //     Nd4jLong rank = shape[0];
 | ||||
|         //     Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
 | ||||
|         //     Nd4jLong dim2 = axis >= 2 ? 1 : 2;
 | ||||
|         //     Nd4jLong dim3 = axis >= 3 ? 2 : 3;
 | ||||
|         //     mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
 | ||||
| 
 | ||||
|             auto type = mkldnn::memory::data_type::f32; | ||||
|             auto format = mkldnn::memory::format_tag::nchw; | ||||
|             auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
 | ||||
|         //     auto type = mkldnn::memory::data_type::f32;
 | ||||
|         //     auto format = mkldnn::memory::format_tag::nchw;
 | ||||
|         //     auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
 | ||||
| 
 | ||||
|             if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { | ||||
|                 *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); | ||||
|                 *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); | ||||
|                 user_src_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|                 user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; | ||||
|                 user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; | ||||
|                 user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; | ||||
|                 user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; | ||||
|             } | ||||
|         //     if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
 | ||||
|         //         *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
 | ||||
|         //         *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
 | ||||
|         //         user_src_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|         //         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
 | ||||
|         //         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
 | ||||
|         //         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
 | ||||
|         //         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
 | ||||
|         //     }
 | ||||
| 
 | ||||
|             if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { | ||||
|                 *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); | ||||
|                 *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); | ||||
|                 user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|                 user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; | ||||
|                 user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; | ||||
|                 user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; | ||||
|                 user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; | ||||
|             } | ||||
|         //     if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
 | ||||
|         //         *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
 | ||||
|         //         *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
 | ||||
|         //         user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|         //         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
 | ||||
|         //         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
 | ||||
|         //         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
 | ||||
|         //         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
 | ||||
|         //     }
 | ||||
| 
 | ||||
|             if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { | ||||
|                 *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); | ||||
|                 *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format); | ||||
|                 user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|                 user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; | ||||
|                 user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; | ||||
|                 user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; | ||||
|                 user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; | ||||
|             } | ||||
|         }; | ||||
|         //     if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
 | ||||
|         //         *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
 | ||||
|         //         *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
 | ||||
|         //         user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
 | ||||
|         //         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
 | ||||
|         //         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
 | ||||
|         //         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
 | ||||
|         //         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
 | ||||
|         //     }
 | ||||
|         // };
 | ||||
| 
 | ||||
| 
 | ||||
|         void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, | ||||
|  | ||||
| @ -62,7 +62,9 @@ namespace nd4j{ | ||||
| 
 | ||||
|             DECLARE_PLATFORM(lrn); | ||||
| 
 | ||||
|             DECLARE_PLATFORM(batchnorm_new); | ||||
|             DECLARE_PLATFORM(batchnorm); | ||||
| 
 | ||||
|             DECLARE_PLATFORM(batchnorm_bp); | ||||
| 
 | ||||
|             DECLARE_PLATFORM(lstmLayer); | ||||
|         } | ||||
|  | ||||
| @ -413,7 +413,7 @@ namespace nd4j { | ||||
|             return ctx; | ||||
|         }; | ||||
| 
 | ||||
|         nd4j::ops::batchnorm_new batchnorm; | ||||
|         nd4j::ops::batchnorm batchnorm; | ||||
|         DeclarableBenchmark benchmark(batchnorm, "batchnorm"); | ||||
|         output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization"); | ||||
| 
 | ||||
| @ -1822,7 +1822,7 @@ namespace nd4j { | ||||
|         std::string result; | ||||
| 
 | ||||
|         long start = nowMs(); | ||||
|          | ||||
| 
 | ||||
|         // set 1
 | ||||
|         nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", ""); | ||||
|         result += fastScalarBenchmark(); | ||||
|  | ||||
| @ -2385,129 +2385,6 @@ TEST_F(DeclarableOpsTests1, CompactLaunchTests2) { | ||||
|     ASSERT_TRUE(exp.equalsTo(&z)); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, batchnorm_test1) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto gamma    = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto beta     = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
| 
 | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011,10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| TEST_F(DeclarableOpsTests1, batchnorm_test2) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {2,3,1,3,1}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {1,3,2,1,2}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,1,2,3,2}); | ||||
|     auto gamma    = NDArrayFactory::create<double>('c', {2,3,2,3,1}); | ||||
|     auto beta     = NDArrayFactory::create<double>('c', {1,3,2,1,2}); | ||||
| 
 | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 1.        , 1.        , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.        , 1.        , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, batchnorm_test3) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {2,3,2}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1}); | ||||
|     auto gamma    = NDArrayFactory::create<double>('c', {1,1}); | ||||
|     auto beta     = NDArrayFactory::create<double>('c', {1,2}); | ||||
| 
 | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011, 10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, batchnorm_test4) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {3,2}); | ||||
|     auto mean    = NDArrayFactory::create<double>('c', {2,3,2}); | ||||
|     auto variance= NDArrayFactory::create<double>('c', {2,3,1,3,2}); | ||||
|     auto gamma   = NDArrayFactory::create<double>('c', {1,1}); | ||||
|     auto beta    = NDArrayFactory::create<double>('c', {1,2}); | ||||
| 
 | ||||
|     auto expected= NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| // TEST_F(DeclarableOpsTests1, sru_old_test1) {
 | ||||
|  | ||||
| @ -2313,7 +2313,35 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { | ||||
| TEST_F(DeclarableOpsTests10, batchnorm_test1) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1},  nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, {10, 20, -10, -20},     nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expected('c', {2,4}, {11.61218734,  18.52390321,  -8.67185076, -21.28716864, 10.93337162,  19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
|     // output->printBuffer();
 | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {2,3,4}); | ||||
|     auto mean     = NDArrayFactory::create<TypeParam>('c', {4}); | ||||
| @ -2330,7 +2358,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_new op; | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
| @ -2346,7 +2374,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) { | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {2,3,4}); | ||||
|     auto mean     = NDArrayFactory::create<TypeParam>('c', {3}, {1.05, 1.1, 1.15}); | ||||
| @ -2359,7 +2387,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_new op; | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); | ||||
| 
 | ||||
| @ -2374,7 +2402,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) { | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { | ||||
| TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {2,3,4}); | ||||
|     auto mean     = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); | ||||
| @ -2387,7 +2415,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_new op; | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); | ||||
| 
 | ||||
| @ -2401,6 +2429,63 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) { | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests10, batchnorm_test5) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1},  nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, {10, 20, -10, -20},     nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expected('c', {2,4,2,2}, {11.612187,  11.442483, 11.272779,  11.103076, 18.990039,  19.145418, 19.300796,  19.456175, -9.557284,  -9.704856, -9.852428, -10., -20., | ||||
|                                       -19.856981, -19.713963, -19.570944, 8.896924,   8.727221, 8.557517,   8.387813, 21.476097,  21.631475, 21.786854,  21.942233, -11.918438, | ||||
|                                       -12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32); | ||||
|     input.linspace(0.1, 0.1); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
|     // output->printBuffer();
 | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests10, batchnorm_test6) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1},  nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, {10, 20, -10, -20},     nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expected('c', {2,2,2,4}, {11.612187,  18.523903,  -8.671851, -21.287169, 10.933372,  19.145418,  -9.262139, -20.715094, 10.254556,  19.766932,  -9.852428, -20.143019, 9.57574 , | ||||
|                                     20.388447, -10.442716, -19.570944,8.896924,  21.009961, -11.033005, -18.998869, 8.218109,  21.631475, -11.623294, -18.426794, 7.539293,  22.25299 , | ||||
|                                     -12.213582, -17.854719, 6.860477,  22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32); | ||||
|     input.linspace(0.1, 0.1); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto output = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShapeStrict(output)); | ||||
|     ASSERT_TRUE(expected.equalsTo(output)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ///////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { | ||||
| 
 | ||||
|  | ||||
| @ -2883,78 +2883,336 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {3,2}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {2,3,2}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,2}); | ||||
|     auto gamma    = NDArrayFactory::create<double>('c', {1,1}); | ||||
|     auto beta     = NDArrayFactory::create<double>('c', {1,2}); | ||||
|     auto dLdO     = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     NDArray input   ('c', {2,3,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,3,4}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556,  0.,  0.254556,0.509112,  0.763668,  1.018224,  1.272779, | ||||
|                                 1.527335,  1.781891,  2.036447,  2.291003,2.545559,  2.800115,  3.054671,  3.309227,3.563783,  3.818338,  4.072894,  4.32745}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     beta.assign(1.); | ||||
|     // beta.assign(1.);     // has no effect on gradient calculations
 | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); | ||||
|     const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1,1}); | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     nd4j::ops::batchnorm opFF; | ||||
|     nd4j::ops::batchnorm_bp opBP; | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     ASSERT_TRUE(isGradCorrect); | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {2,3,2}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1}); | ||||
|     auto gamma    = NDArrayFactory::create<double>('c', {1,1}); | ||||
|     auto dLdO     = NDArrayFactory::create<double>('c', {2,3,2,3,2}); | ||||
|     NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE); | ||||
|     NDArray mean    ('c', {3}, {1.05, 1.1, 1.15}); | ||||
|     NDArray variance('c', {3}, {0.5, 0.6, 0.7}); | ||||
|     NDArray gamma   ('c', {3}, {1.2, 1.3, 1.4}); | ||||
|     NDArray beta    ('c', {3}, nd4j::DataType::DOUBLE); | ||||
|     NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742,  0.,  0.251742,0.501992,  0.752989,  1.003985,  1.254981, | ||||
|                                     1.527335,  1.781891,  2.036447,  2.291003,2.517418,  2.76916 ,  3.020902,  3.272644,3.513947,  3.764943,  4.015939,  4.266936}); | ||||
|     NDArray expdLdG('c', {3}, {5.81236 ,  7.048771, 12.155388}); | ||||
|     NDArray expdLdB('c', {3}, {1.8,  6.6, 11.4}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     gamma.assign(1.2); | ||||
|     // beta.assign(1.);     // has no effect on gradient calculations
 | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma}, {1e-5}, {1,0}); | ||||
|     const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1,0}); | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     nd4j::ops::batchnorm opFF; | ||||
|     nd4j::ops::batchnorm_bp opBP; | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); | ||||
| 
 | ||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     ASSERT_TRUE(isGradCorrect); | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { | ||||
| 
 | ||||
|     auto input    = NDArrayFactory::create<double>('c', {2,3,1,3}); | ||||
|     auto mean     = NDArrayFactory::create<double>('c', {1,3,2,1}); | ||||
|     auto variance = NDArrayFactory::create<double>('c', {2,1,2,3}); | ||||
|     auto dLdO     = NDArrayFactory::create<double>('c', {2,3,2,3}); | ||||
|     NDArray input   ('c', {2,3,4}, nd4j::DataType::DOUBLE); | ||||
|     NDArray mean    ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); | ||||
|     NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); | ||||
|     NDArray gamma   ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}); | ||||
|     NDArray beta    ('c', {2,1,4}, nd4j::DataType::DOUBLE); | ||||
|     NDArray gradO   ('c', {2,3,4}, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742,  0.,  0.251556,0.509112,  0.755225,  1.003985,  1.25778 , | ||||
|                                    1.517885,  1.784991,  2.05947 ,  2.341504,2.529808,  2.804986,  3.089205,  3.382173,3.541731,  3.824981,  4.11894 ,  4.422841}); | ||||
|     NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }); | ||||
|     NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45,  0.  ,  0.45,  4.5 ,  4.95,  5.4 ,  5.85}); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     mean.assign(1.); | ||||
|     variance.assign(0.5); | ||||
|     // beta.assign(1.);     // has no effect on gradient calculations
 | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     const OpArgsHolder argsHolderFF({&input, &mean, &variance}, {1e-5}, {0,0}); | ||||
|     const OpArgsHolder argsHolderBP({&input, &mean, &variance, &dLdO}, {1e-5}, {0,0}); | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     nd4j::ops::batchnorm opFF; | ||||
|     nd4j::ops::batchnorm_bp opBP; | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2}); | ||||
| 
 | ||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     ASSERT_TRUE(isGradCorrect); | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,4}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 ,  0.885433, -0.643584,  0.509112, -0.233068, -0.,  0.214528}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {1.442483, 0.9502  , 0.569207, 0.314641}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,4,2,2}, {1.527335,  1.272779,1.018224,  0.763668,-0.466136, -0.233068,0.,  0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169,  1.501697,1.716225,  1.930753, | ||||
|                                     -2.545559, -2.800115,-3.054671, -3.309227,3.262951,  3.496019,3.729087,  3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618,  4.934146,5.148675,  5.363203}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {4.2,  9. , 13.8, 18.6}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 ,  0.885433, -0.643584, 0.509112, -0.233068, -0.,  0.214528, -0.509112,  0.699204, -0.885433,  1.072641, -1.527335,  1.631475, -1.770866,  1.930753, | ||||
|                                     -2.545559,  2.563747, -2.656298,  2.788865, -3.563783,  3.496019, -3.541731,  3.646978, -4.582006,  4.42829 , -4.427164,  4.50509 , -5.60023 ,  5.360562, -5.312597,  5.363203}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,2,2,2,4}, {1.527335,  -1.16534 ,   0.885433,  -0.643584,0.509112,  -0.233068,  -0.,   0.214528,-0.509112,   0.699204,  -0.885433,   1.072641,-1.527335,   1.631475,  -1.770866, | ||||
|                                       1.930753,-2.545559,   2.563747,  -2.656298,   2.788865,-3.563783,   3.496019,  -3.541731,   3.646978,-4.582006,   4.42829 ,  -4.427164, | ||||
|                                       4.50509 ,-5.60023 ,   5.360562,  -5.312597,   5.363203,  -6.618453,   6.292834,  -6.19803 ,   6.221315,-7.636677,   7.225105,  -7.083463, | ||||
|                                       7.079428,-8.6549  ,   8.157377,  -7.968895,   7.93754 ,-9.673124,   9.089649,  -8.854328,   8.795652, -10.691348,  10.02192 ,  -9.739761, | ||||
|                                       9.653765,-11.709571,  10.954192, -10.625194,  10.511877,-12.727795,  11.886464, -11.510627,  11.36999 ,-13.746018,  12.818735, -12.39606 ,  12.228102}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     // dLdI->printBuffer();
 | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) { | ||||
| 
 | ||||
|     NDArray input   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); | ||||
|     NDArray mean    ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); | ||||
|     NDArray variance('c', {4}, {0.5, 0.7, 0.9,  1.1}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gamma   ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); | ||||
|     NDArray beta    ('c', {4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray gradO   ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expdLdI('c', {2,4,2,2,2}, {1.527335,   1.272779, 1.018224,   0.763668, 0.509112,   0.254556, -0.      ,  -0.254556, 0.466136,   0.699204, 0.932272,   1.16534 , 1.398407,   1.631475, 1.864543,   2.097611, | ||||
|                                     -2.213582,  -2.43494 , -2.656298,  -2.877657, -3.099015,  -3.320373, -3.541731,  -3.76309 , 3.861506,   4.076034, 4.290562,   4.50509 , 4.719618,   4.934146, 5.148675,   5.363203, | ||||
|                                     -6.618453,  -6.873009, -7.127565,  -7.382121, -7.636677,  -7.891233, -8.145789,  -8.400345, 7.924309,   8.157377, 8.390445,   8.623513, 8.856581,   9.089649, 9.322717,   9.555784, | ||||
|                                     -9.297045,  -9.518403, -9.739761,  -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405,  10.940933, 11.155462,  11.36999 , 11.584518,  11.799046, 12.013574,  12.228102}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32); | ||||
|     NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     input.linspace(0.1, 0.1); | ||||
|     gradO.linspace(-0.9, 0.15); | ||||
| 
 | ||||
|     nd4j::ops::batchnorm_bp op; | ||||
| 
 | ||||
|     auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto dLdI = results->at(0); | ||||
|     auto dLdG = results->at(3); | ||||
|     auto dLdB = results->at(4); | ||||
| 
 | ||||
|     // dLdI->printBuffer();
 | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI)); | ||||
|     ASSERT_TRUE(expdLdI.equalsTo(dLdI)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG)); | ||||
|     ASSERT_TRUE(expdLdG.equalsTo(dLdG)); | ||||
| 
 | ||||
|     ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB)); | ||||
|     ASSERT_TRUE(expdLdB.equalsTo(dLdB)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| /*
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { | ||||
|  | ||||
| @ -64,7 +64,7 @@ TEST_F(MklDnnTests, helpers_includer) { | ||||
|     nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp; | ||||
| 
 | ||||
|     nd4j::ops::platforms::PLATFORM_lrn lrn; | ||||
|     nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm; | ||||
|     nd4j::ops::platforms::PLATFORM_batchnorm batchnorm; | ||||
| 
 | ||||
|     printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm}); | ||||
| #endif | ||||
|  | ||||
| @ -142,7 +142,7 @@ public class BatchNorm extends DynamicCustomOp { | ||||
| 
 | ||||
|     @Override | ||||
|     public String opName() { | ||||
|         return "batchnorm_new"; | ||||
|         return "batchnorm"; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user