Shyrma bn mkl bp (#14)
* - write code for new batchnorm backprop Signed-off-by: Yurii <iuriish@yahoo.com> * - testing batchnorm backprop Signed-off-by: Yurii <iuriish@yahoo.com> * - write code for batchnorm backprop based on mkl dnn api Signed-off-by: Yurii <iuriish@yahoo.com> * - testing and fixing bugs in batchnorm_bp mkl dnn Signed-off-by: Yurii <iuriish@yahoo.com> * - made corrections required by reviewer Signed-off-by: Yurii <iuriish@yahoo.com> * - change name in java wrapper for batchnorm op Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
d333d29099
commit
029a69a835
|
@ -60,6 +60,7 @@ namespace nd4j {
|
||||||
Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
|
Nd4jLong* createShapeInfo(const ShapeDescriptor &descriptor);
|
||||||
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape);
|
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape);
|
||||||
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
|
||||||
|
Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo);
|
||||||
|
|
||||||
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace);
|
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, nd4j::memory::Workspace *workspace);
|
||||||
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
|
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
|
||||||
|
|
|
@ -99,6 +99,10 @@ namespace nd4j {
|
||||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) {
|
||||||
|
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
|
||||||
|
}
|
||||||
|
|
||||||
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
|
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
|
||||||
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
||||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
|
|
@ -102,6 +102,10 @@ namespace nd4j {
|
||||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const Nd4jLong* shapeInfo) {
|
||||||
|
return ConstantShapeHelper::createShapeInfo(dataType, shape::order(shapeInfo), shape::rank(shapeInfo), shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)));
|
||||||
|
}
|
||||||
|
|
||||||
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
|
Nd4jLong* ConstantShapeHelper::emptyShapeInfo(const nd4j::DataType dataType) {
|
||||||
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
auto descriptor = ShapeDescriptor::emptyDescriptor(dataType);
|
||||||
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
return bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
|
|
@ -29,84 +29,8 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto mean = INPUT_VARIABLE(1);
|
|
||||||
auto variance = INPUT_VARIABLE(2);
|
|
||||||
NDArray *gamma = nullptr;
|
|
||||||
NDArray *beta = nullptr;
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
const bool applyScale = (bool)INT_ARG(0);
|
|
||||||
const bool applyOffset = (bool)INT_ARG(1);
|
|
||||||
|
|
||||||
// FIXME: double?
|
|
||||||
const double epsilon = T_ARG(0);
|
|
||||||
|
|
||||||
if(applyScale)
|
|
||||||
gamma = INPUT_VARIABLE(3);
|
|
||||||
if(applyOffset)
|
|
||||||
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
|
||||||
|
|
||||||
std::vector<const NDArray*> inArrs(block.width());
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
|
||||||
|
|
||||||
// check whether all input shapes are mutually broadcastable
|
|
||||||
Nd4jLong* outShapeInfo = nullptr;
|
|
||||||
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
|
|
||||||
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !");
|
|
||||||
|
|
||||||
// normalized output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
|
||||||
|
|
||||||
auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt);
|
|
||||||
if(applyScale)
|
|
||||||
sigmaInvGam *= *gamma;
|
|
||||||
|
|
||||||
NDArray inputMinusMean;
|
|
||||||
if(!input->isSameShape(output) && !mean->isSameShape(output)) {
|
|
||||||
auto inputTiled = NDArray(output, false, block.launchContext());
|
|
||||||
input->tile(inputTiled);
|
|
||||||
inputMinusMean = inputTiled - *mean;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
inputMinusMean = *input - *mean;
|
|
||||||
|
|
||||||
if (applyOffset)
|
|
||||||
output->assign(inputMinusMean * sigmaInvGam + *beta);
|
|
||||||
else
|
|
||||||
output->assign(inputMinusMean * sigmaInvGam);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(batchnorm) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_SHAPE_FN(batchnorm) {
|
CUSTOM_OP_IMPL(batchnorm, 3, 1, false, 1, 2) {
|
||||||
|
|
||||||
std::vector<const NDArray*> inArrs(block.width());
|
|
||||||
auto in = inputShape->at(0);
|
|
||||||
for(int i = 0; i < block.width(); ++i)
|
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
|
||||||
|
|
||||||
// check whether all input shapes are mutually broadcastable
|
|
||||||
Nd4jLong* outShapeInfo = nullptr;
|
|
||||||
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
|
|
||||||
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM op: the shapes of input arrays are not mutually broadcastable !");
|
|
||||||
|
|
||||||
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo, DataTypeUtils::pickFloatingType(ArrayOptions::dataType(in))));
|
|
||||||
return SHAPELIST(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto mean = INPUT_VARIABLE(1);
|
auto mean = INPUT_VARIABLE(1);
|
||||||
|
@ -123,7 +47,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
||||||
if(applyScale)
|
if(applyScale)
|
||||||
gamma = INPUT_VARIABLE(3);
|
gamma = INPUT_VARIABLE(3);
|
||||||
if(applyOffset)
|
if(applyOffset)
|
||||||
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
beta = INPUT_VARIABLE(3 + (int)applyScale);
|
||||||
|
|
||||||
const int numOfIntArgs = block.getIArguments()->size();
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
|
@ -137,30 +61,31 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
||||||
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
const int numOfAxes = axes.size();
|
const int numOfAxes = axes.size();
|
||||||
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_NEW op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
|
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
|
||||||
|
|
||||||
// get, for example, something like {1, inDim1, 1, inDim3, 1} if axes = {1, 3}
|
|
||||||
std::vector<Nd4jLong> expShapeWithUnities(inRank, 1);
|
|
||||||
for(int i = 0; i < numOfAxes; ++i)
|
|
||||||
expShapeWithUnities[axes[i]] = input->sizeAt(axes[i]);
|
|
||||||
|
|
||||||
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
|
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
|
||||||
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
|
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
|
||||||
std::vector<Nd4jLong> expShape = numOfAxes == 1 ? std::vector<Nd4jLong>(1, input->sizeAt(axes[0])) : expShapeWithUnities;
|
std::vector<Nd4jLong> expShape;
|
||||||
std::string expShapeStr = ShapeUtils::shapeAsString(expShape);
|
if(numOfAxes == 1)
|
||||||
|
expShape.push_back(input->sizeAt(axes[0]));
|
||||||
|
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
|
||||||
|
expShape = std::vector<Nd4jLong>(inRank, 1);
|
||||||
|
for(uint i = 0; i < numOfAxes; ++i)
|
||||||
|
expShape[axes[i]] = input->sizeAt(axes[i]);
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(mean) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of mean array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(mean).c_str());
|
REQUIRE_TRUE(mean->isSameShape(expShape) , 0, "BATCHNORM op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
|
||||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(variance) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of variance array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(variance).c_str());
|
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
|
||||||
if(gamma)
|
if(gamma)
|
||||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(gamma) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of gamma array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(gamma).c_str());
|
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
|
||||||
if(beta)
|
if(beta)
|
||||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(beta) == expShapeStr, 0, "BATCHNORM_NEW op: wrong shape of beta array, expected is %s, but got %s instead !", expShapeStr.c_str(), ShapeUtils::shapeAsString(beta).c_str());
|
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
|
||||||
|
|
||||||
// types of all input arrays should be the same
|
// types of all input arrays should be the same
|
||||||
for(int i = 1; i < block.width(); ++i)
|
for(int i = 1; i < block.width(); ++i)
|
||||||
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_NEW op: types of all input arrays should be the same !");
|
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM op: types of all input arrays should be the same !");
|
||||||
|
|
||||||
nd4j_debug("MKL-DNN is not used for batchnorm_new!\n", 0);
|
nd4j_debug("MKL-DNN is not used for batchnorm!\n", 0);
|
||||||
|
|
||||||
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
||||||
helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
|
helpers::batchnorm(input, mean, variance, gamma, beta, output, axes, epsilon);
|
||||||
|
@ -168,15 +93,15 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(batchnorm_new) {
|
DECLARE_TYPES(batchnorm) {
|
||||||
getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true);
|
getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS})->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(batchnorm_new) {
|
DECLARE_SHAPE_FN(batchnorm) {
|
||||||
|
|
||||||
auto inShapeInfo = inputShape->at(0);
|
auto inShapeInfo = inputShape->at(0);
|
||||||
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
|
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
|
||||||
|
|
||||||
auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape
|
auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(inShapeInfo, outType, false, block.getWorkspace()); // output shape is identical to input shape
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||||
|
@ -184,290 +109,177 @@ DECLARE_SHAPE_FN(batchnorm_new) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
|
CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto mean = INPUT_VARIABLE(1);
|
|
||||||
auto variance = INPUT_VARIABLE(2);
|
|
||||||
NDArray *gamma = nullptr;
|
|
||||||
NDArray *beta = nullptr;
|
|
||||||
NDArray *dLdO = nullptr; // next epsilon
|
|
||||||
|
|
||||||
auto dLdI = OUTPUT_VARIABLE(0);
|
NDArray* input = INPUT_VARIABLE(0);
|
||||||
auto dLdM = OUTPUT_VARIABLE(1);
|
NDArray* mean = INPUT_VARIABLE(1);
|
||||||
auto dLdV = OUTPUT_VARIABLE(2);
|
NDArray* variance = INPUT_VARIABLE(2);
|
||||||
NDArray *dLdG = nullptr;
|
NDArray* dLdO = INPUT_VARIABLE(3); // next epsilon
|
||||||
NDArray *dLdB = nullptr;
|
NDArray* gamma = nullptr;
|
||||||
|
NDArray* beta = nullptr;
|
||||||
|
|
||||||
const bool applyScale = (bool)INT_ARG(0);
|
|
||||||
const bool applyOffset = (bool)INT_ARG(1);
|
|
||||||
|
|
||||||
// FIXME: double?
|
NDArray* dLdI = OUTPUT_VARIABLE(0);
|
||||||
const double epsilon = T_ARG(0);
|
NDArray* dLdM = OUTPUT_VARIABLE(1);
|
||||||
|
NDArray* dLdV = OUTPUT_VARIABLE(2);
|
||||||
|
NDArray* dLdG = nullptr;
|
||||||
|
NDArray* dLdB = nullptr;
|
||||||
|
|
||||||
const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset);
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
const float epsilon = T_ARG(0);
|
||||||
|
|
||||||
if(applyScale) {
|
if(applyScale) {
|
||||||
gamma = INPUT_VARIABLE(3);
|
gamma = INPUT_VARIABLE(4);
|
||||||
dLdG = OUTPUT_VARIABLE(3);
|
dLdG = OUTPUT_VARIABLE(3);
|
||||||
}
|
}
|
||||||
if(applyOffset) {
|
if(applyOffset) {
|
||||||
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
beta = INPUT_VARIABLE(4 + (int)applyScale);
|
||||||
dLdB = OUTPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
|
||||||
}
|
}
|
||||||
|
|
||||||
dLdO = INPUT_VARIABLE(3 + dLdONum);
|
|
||||||
|
|
||||||
std::vector<const NDArray*> inArrs(block.width());
|
|
||||||
for(int i = 0; i < 4 + dLdONum; ++i)
|
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
|
||||||
|
|
||||||
// check whether all input shapes are mutually broadcastable
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
Nd4jLong* outShapeInfo = nullptr;
|
const int inRank = input->rankOf();
|
||||||
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
|
|
||||||
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !");
|
// get axes args to normalize input array over
|
||||||
|
std::vector<int> axes;
|
||||||
|
if(numOfIntArgs > 2)
|
||||||
|
for(int i = 2; i < numOfIntArgs; ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
|
const int numOfAxes = axes.size();
|
||||||
|
REQUIRE_TRUE(numOfAxes <= inRank, 0, "BATCHNORM_BP op: too big number of input axes to normalize over, expected number should be less or equal to rank of input array, but got %i and %i correspondingly !", numOfAxes, inRank);
|
||||||
|
|
||||||
|
// evaluate expected shape for mean, variance and gamma. These 3 arrays should have identical shapes
|
||||||
|
// for example if input shape is {2,3,4,5,6} and axes = {1,3}, then expected shape would be {1,3,1,5,1}, and if axes = {3}, then expected shape would be {5}
|
||||||
|
std::vector<Nd4jLong> expShape;
|
||||||
|
if(numOfAxes == 1)
|
||||||
|
expShape.push_back(input->sizeAt(axes[0]));
|
||||||
|
else { // get, for example, something like {1, inputDim1, 1, inputDim3, 1} if axes = {1, 3}
|
||||||
|
expShape = std::vector<Nd4jLong>(inRank, 1);
|
||||||
|
for(uint i = 0; i < numOfAxes; ++i)
|
||||||
|
expShape[axes[i]] = input->sizeAt(axes[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
REQUIRE_TRUE(mean->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of mean array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(mean).c_str());
|
||||||
|
REQUIRE_TRUE(variance->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of variance array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(variance).c_str());
|
||||||
|
if(gamma)
|
||||||
|
REQUIRE_TRUE(gamma->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of gamma array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(gamma).c_str());
|
||||||
|
if(beta)
|
||||||
|
REQUIRE_TRUE(beta->isSameShape(expShape), 0, "BATCHNORM_BP op: wrong shape of beta array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expShape).c_str(), ShapeUtils::shapeAsString(beta).c_str());
|
||||||
|
|
||||||
|
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP op: wrong shape of output gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
|
||||||
|
|
||||||
|
// types of all input arrays should be the same (except dLdO)
|
||||||
|
for(int i = 1; i < block.width() - 1; ++i)
|
||||||
|
if(i != 3)
|
||||||
|
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP op: types of arrays (input, mean, variance, gamma, beta) should be the same !");
|
||||||
|
|
||||||
// ***** calculations ***** //
|
// ***** calculations ***** //
|
||||||
|
|
||||||
auto sigmaInv = (*variance + epsilon).transform(transform::RSqrt);
|
// formula for forward step: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
||||||
|
|
||||||
NDArray sigmaInvGamdLdO = -sigmaInv * *dLdO;
|
|
||||||
if(applyScale)
|
|
||||||
sigmaInvGamdLdO *= *gamma;
|
|
||||||
|
|
||||||
NDArray inputMinusMean;
|
// consider mean and variance as constants (since we get them as inputs and don't calculate them)
|
||||||
if(!input->isSameShape(dLdO) && !mean->isSameShape(dLdO)) {
|
// dLdI = (dLdO * gamma) / (variance + epsilon)^0.5
|
||||||
auto inputTiled = NDArray(dLdO, false, block.launchContext());
|
// dLdV = (-0.5 * gamma * (dLdO * (x - mean))_sum) / (variance + epsilon)^1.5
|
||||||
input->tile(inputTiled);
|
// dLdM = - (dLdO_sum * gamma) / (variance + epsilon)^0.5
|
||||||
inputMinusMean = inputTiled - *mean;
|
// dLdG = (dLdO * (x - mean))_sum / (variance + epsilon)^0.5
|
||||||
}
|
// dLdB = dLdO_sum
|
||||||
else
|
|
||||||
inputMinusMean = *input - *mean;
|
const auto excludedAxes = ShapeUtils::evalDimsToExclude(inRank, axes);
|
||||||
|
|
||||||
|
NDArray temp1 = *variance + epsilon;
|
||||||
|
temp1.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
|
||||||
|
auto temp2 = temp1.transform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
|
||||||
|
if(applyScale)
|
||||||
|
temp2 *= *gamma; // gamma / (variance + epsilon)^0.5
|
||||||
|
|
||||||
|
NDArray temp3(input); // empty array with same shape as input
|
||||||
|
input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &temp3); // input - mean
|
||||||
|
temp3 *= *dLdO; // (input - mean) * dLdO
|
||||||
|
|
||||||
|
const bool keepUnitiesInShape = inRank == mean->rankOf();
|
||||||
|
|
||||||
// dLdI
|
// dLdI
|
||||||
if(!dLdI->isSameShape(dLdO))
|
dLdO->applyBroadcast(nd4j::broadcast::Multiply, axes, &temp2, dLdI);
|
||||||
dLdI->assign( (-sigmaInvGamdLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdI->getShapeInfo(), dLdO->getShapeInfo())) );
|
|
||||||
else
|
|
||||||
dLdI->assign(-sigmaInvGamdLdO);
|
|
||||||
|
|
||||||
// dLdM
|
// dLdM
|
||||||
if(!dLdM->isSameShape(dLdO))
|
dLdO->reduceAlongDimension(reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); // dLdO sum over excluded axes
|
||||||
dLdM->assign( sigmaInvGamdLdO.reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdM->getShapeInfo(), dLdO->getShapeInfo())) );
|
|
||||||
else
|
|
||||||
dLdM->assign(sigmaInvGamdLdO);
|
|
||||||
|
|
||||||
// dLdV
|
// dLdB
|
||||||
if(!dLdV->isSameShape(dLdO)) {
|
if(applyOffset)
|
||||||
dLdV->assign( (sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdV->getShapeInfo(), dLdO->getShapeInfo())) );
|
dLdB->assign(dLdM);
|
||||||
}
|
|
||||||
else
|
// dLdM
|
||||||
dLdV->assign(sigmaInv * sigmaInv * sigmaInvGamdLdO * inputMinusMean * 0.5f);
|
// dLdM->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
|
||||||
|
// dLdM->applyTransform(nd4j::transform::Neg);
|
||||||
|
*dLdM = 0; // put zeros so far
|
||||||
|
|
||||||
|
//dLdV
|
||||||
|
temp3.reduceAlongDimension(reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); // ((input - mean) * dLdO)_sum
|
||||||
|
|
||||||
// dLdG
|
// dLdG
|
||||||
if(applyScale) {
|
if(applyScale) {
|
||||||
if(!dLdG->isSameShape(dLdO))
|
dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, &temp2, dLdG);
|
||||||
dLdG->assign( (sigmaInv * inputMinusMean * *dLdO).reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdG->getShapeInfo(), dLdO->getShapeInfo())) );
|
// dLdV->assign(dLdG);
|
||||||
else
|
dLdG->applyPairwiseTransform(nd4j::pairwise::Divide, *gamma);
|
||||||
dLdG->assign(sigmaInv * inputMinusMean * *dLdO);
|
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp2);
|
||||||
|
|
||||||
// dLdB
|
// dLdV
|
||||||
if(applyOffset) {
|
// dLdV->applyPairwiseTransform(nd4j::pairwise::Multiply, temp1);
|
||||||
if(!dLdB->isSameShape(dLdO))
|
// *dLdV *= -0.5;
|
||||||
dLdB->assign(dLdO->reduceAlongDims(reduce::Sum, ShapeUtils::evalBroadcastBackwardAxis(dLdB->getShapeInfo(), dLdO->getShapeInfo())) );
|
*dLdV = 0; // put zeros so far
|
||||||
else
|
|
||||||
dLdB->assign(dLdO);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(batchnorm_bp) {
|
DECLARE_TYPES(batchnorm_bp) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(1, nd4j::DataType::ANY)
|
->setAllowedInputTypes(1, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(2, nd4j::DataType::ANY)
|
->setAllowedInputTypes(2, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(3, nd4j::DataType::ANY)
|
->setAllowedInputTypes(3, {ALL_FLOATS})
|
||||||
->setAllowedInputTypes(4, nd4j::DataType::ANY)
|
->setAllowedInputTypes(4, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(5, {ALL_FLOATS})
|
->setAllowedInputTypes(5, nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(batchnorm_bp) {
|
DECLARE_SHAPE_FN(batchnorm_bp) {
|
||||||
|
|
||||||
|
Nd4jLong* inShapeInfo = inputShape->at(0);
|
||||||
|
Nd4jLong* meanShapeInfo = inputShape->at(1);
|
||||||
|
|
||||||
const bool applyScale = (bool)INT_ARG(0);
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
const bool applyOffset = (bool)INT_ARG(1);
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
|
||||||
const int dLdONum = static_cast<int>(applyScale) + static_cast<int>(applyOffset);
|
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(inShapeInfo));
|
||||||
|
|
||||||
std::vector<const NDArray*> inArrs(block.width());
|
auto shapes = SHAPELIST();
|
||||||
for(int i = 0; i < 4 + dLdONum; ++i)
|
|
||||||
inArrs[i] = INPUT_VARIABLE(i);
|
|
||||||
|
|
||||||
// check whether all input shapes are mutually broadcastable
|
// dLdI shapeInfo
|
||||||
Nd4jLong* outShapeInfo = nullptr;
|
shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, inShapeInfo));
|
||||||
const bool areShapesOk = ShapeUtils::evalCommonBroadcastShapeInfo(inArrs, outShapeInfo, block.getWorkspace());
|
|
||||||
REQUIRE_TRUE(areShapesOk, 0, "BATCHNORM_BP op: the shapes of input arrays are not mutually broadcastable !");
|
|
||||||
|
|
||||||
Nd4jLong* dLdIShapeInfo(nullptr), *dLdMShapeInfo(nullptr), *dLdVShapeInfo(nullptr), *dLdGShapeInfo(nullptr), *dLdBShapeInfo(nullptr);
|
// dLdM shapeInfo
|
||||||
COPY_SHAPE(inputShape->at(0), dLdIShapeInfo);
|
shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(outType, meanShapeInfo));
|
||||||
COPY_SHAPE(inputShape->at(1), dLdMShapeInfo);
|
|
||||||
COPY_SHAPE(inputShape->at(2), dLdVShapeInfo);
|
|
||||||
|
|
||||||
if(applyScale) {
|
// dLdV shapeInfo (same as dLdM)
|
||||||
COPY_SHAPE(inputShape->at(3), dLdGShapeInfo);
|
shapes->push_back(shapes->at(shapes->size()-1));
|
||||||
}
|
|
||||||
if(applyOffset){
|
|
||||||
COPY_SHAPE(inputShape->at(3 + static_cast<int>(applyScale)), dLdBShapeInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(!applyScale && !applyOffset)
|
// dLdG shapeInfo (same as dLdM)
|
||||||
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo));
|
if(applyScale)
|
||||||
|
shapes->push_back(shapes->at(shapes->size()-1));
|
||||||
|
|
||||||
if(applyScale && !applyOffset)
|
// dLdB shapeInfo (same as dLdM)
|
||||||
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo));
|
if(applyOffset)
|
||||||
|
shapes->push_back(shapes->at(shapes->size()-1));
|
||||||
|
|
||||||
if(!applyScale && applyOffset)
|
return shapes;
|
||||||
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdBShapeInfo));
|
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(dLdIShapeInfo), CONSTANT(dLdMShapeInfo), CONSTANT(dLdVShapeInfo), CONSTANT(dLdGShapeInfo), CONSTANT(dLdBShapeInfo));
|
|
||||||
}
|
}
|
||||||
// //////////////////////////////////////////////////////////////////////////
|
|
||||||
// CONFIGURABLE_OP_IMPL(batchnorm_bp, 5, 1, true, 0, 1) {
|
|
||||||
|
|
||||||
// NDArray<T>* input = INPUT_VARIABLE(0);
|
|
||||||
// NDArray<T>* epsilon = INPUT_VARIABLE(1);
|
|
||||||
// NDArray<T>* gamma = INPUT_VARIABLE(2);
|
|
||||||
// NDArray<T>* dGlobalMeanView = INPUT_VARIABLE(3);
|
|
||||||
// NDArray<T>* dGlobalVarView = INPUT_VARIABLE(4);
|
|
||||||
// NDArray<T>* outEpsilon = this->getZ(block);
|
|
||||||
// std::vector<int> argI = *(block.getIArguments());
|
|
||||||
// const int bS = epsilon->sizeAt(0);
|
|
||||||
// bool isLockGammaBeta = (bool)argI[0];
|
|
||||||
// const int* epsilonShape = epsilon->getShapeInfo() + 1;
|
|
||||||
// const T eps = (T)1e-5;
|
|
||||||
|
|
||||||
// int rank = epsilon->rankOf();
|
|
||||||
// std::initializer_list<int> dimensions;
|
|
||||||
// int effectiveBatchSize;
|
|
||||||
// if (rank == 2) {
|
|
||||||
// dimensions = {0};
|
|
||||||
// effectiveBatchSize = bS;
|
|
||||||
// }
|
|
||||||
// else if (rank == 4) {
|
|
||||||
// dimensions = {0, 2, 3};
|
|
||||||
// effectiveBatchSize = input->sizeAt(0)*input->sizeAt(2)*input->sizeAt(3);
|
|
||||||
// }
|
|
||||||
// else
|
|
||||||
// throw "Graph operation batchnorm_bp: the epsilon rank must be equal to 2 or 4 !";
|
|
||||||
|
|
||||||
// NDArray<T> *mean(nullptr), *var(nullptr), *dBeta(nullptr), *dGamma(nullptr), *dLdVar(nullptr), *dxmu1(nullptr), *dxmu2(nullptr);
|
|
||||||
// mean = input->template reduceAlongDimension<simdOps::Mean<T>>(dimensions);
|
|
||||||
// var = input->template varianceAlongDimension<simdOps::SummaryStatsVariance<T>>(false, dimensions);
|
|
||||||
// var->template applyScalar<simdOps::Add<T>>(eps, nullptr);
|
|
||||||
// auto std = new NDArray<T>(var->getShapeInfo(), block.getWorkspace());
|
|
||||||
// var->template applyTransform<simdOps::Sqrt<T>>(std, nullptr);
|
|
||||||
|
|
||||||
// auto xMu = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
|
|
||||||
// auto xHat = new NDArray<T>(input->getShapeInfo(), block.getWorkspace());
|
|
||||||
// auto temp1 = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
|
|
||||||
// auto temp2 = new NDArray<T>(std->getShapeInfo(), block.getWorkspace());
|
|
||||||
// auto dGammaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
|
|
||||||
// auto dBetaView = new NDArray<T>('c', {1, epsilonShape[1]}, block.getWorkspace());
|
|
||||||
// auto dxhat = new NDArray<T>(epsilon->getShapeInfo(), block.getWorkspace());
|
|
||||||
|
|
||||||
// if (rank == 2) {
|
|
||||||
// input->subRowVector(mean, xMu);
|
|
||||||
// xMu->divRowVector(std, xHat);
|
|
||||||
// }
|
|
||||||
// else {
|
|
||||||
// input->template applyBroadcast<simdOps::Subtract<T>>({1}, mean, xMu, nullptr);
|
|
||||||
// xMu->template applyBroadcast<simdOps::Divide<T>>({1}, std, xHat, nullptr);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// dBeta = epsilon->sum(dimensions); // dL/dBeta = sum_examples dL/dOut
|
|
||||||
// epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(xHat, temp1, nullptr); //dL/dGamma = sum_examples dL/dOut .* xHat
|
|
||||||
// dGamma = temp1->sum(dimensions); //dL/dGamma = sum_examples dL/dOut .* xHat
|
|
||||||
|
|
||||||
// if (isLockGammaBeta)
|
|
||||||
// epsilon->template applyPairwiseTransform<simdOps::Multiply<T>>(gamma, dxhat, nullptr);
|
|
||||||
// else {// Standard case
|
|
||||||
// if(rank == 2)
|
|
||||||
// epsilon->mulRowVector(gamma, dxhat); //dL/dxHat = dL/dOut . gamma Shape: [minibatchSize, nOut]
|
|
||||||
// else
|
|
||||||
// epsilon->template applyBroadcast<simdOps::Multiply<T>>({1}, gamma, dxhat, nullptr);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // dLdVar - dL/dVariance, shape: [1, miniBatch]
|
|
||||||
// dxhat->template applyPairwiseTransform<simdOps::Multiply<T>>(xMu, temp1, nullptr);
|
|
||||||
// dLdVar = temp1->sum(dimensions);
|
|
||||||
// dLdVar->template applyScalar<simdOps::Multiply<T>>((T)-0.5, nullptr);
|
|
||||||
// T powParams[] = {(T)(-3.)};
|
|
||||||
// std->template applyTransform<simdOps::Pow<T>>(temp2, powParams);
|
|
||||||
// dLdVar->template applyPairwiseTransform<simdOps::Multiply<T>>(temp2, nullptr);
|
|
||||||
|
|
||||||
// //dL/dmu
|
|
||||||
// dxmu1 = dxhat->sum(dimensions);
|
|
||||||
// dxmu1->template applyPairwiseTransform<simdOps::Divide<T>>(std, nullptr);
|
|
||||||
// dxmu1->template applyTransform<simdOps::Neg<T>>();
|
|
||||||
// dxmu2 = xMu->sum(dimensions);
|
|
||||||
// dxmu2->template applyScalar<simdOps::Multiply<T>>((T)(-2.)/effectiveBatchSize);
|
|
||||||
// dxmu2->template applyPairwiseTransform<simdOps::Multiply<T>>(dLdVar, nullptr);
|
|
||||||
|
|
||||||
// dxmu1->template applyPairwiseTransform<simdOps::Add<T>>(dxmu2, nullptr);
|
|
||||||
// NDArray<T>* dLdmu = dxmu1; // = dL/dmu Shape: [1, nOut]
|
|
||||||
|
|
||||||
// //Note the array reuse here: dxhat, xMu, dLdVar, dLdmu - all are invalid after this line (but aren't used later anyway)
|
|
||||||
// NDArray<T>* dLdx = dxhat;
|
|
||||||
// dLdVar->template applyScalar<simdOps::Multiply<T>>((T)(2.)/effectiveBatchSize);
|
|
||||||
// dLdmu->template applyScalar<simdOps::Multiply<T>>((T)(1.)/effectiveBatchSize);
|
|
||||||
// if(rank == 2) {
|
|
||||||
// dLdx->divRowVector(std, dLdx);
|
|
||||||
// xMu->mulRowVector(dLdVar, xMu);
|
|
||||||
// }
|
|
||||||
// else {
|
|
||||||
// dLdx->template applyBroadcast<simdOps::Divide<T>>({1}, std, dLdx, nullptr);
|
|
||||||
// xMu->template applyBroadcast<simdOps::Multiply<T>>({1}, dLdVar, xMu, nullptr);
|
|
||||||
// }
|
|
||||||
// dLdx->template applyPairwiseTransform<simdOps::Add<T>>(xMu, nullptr);
|
|
||||||
// if(rank == 2)
|
|
||||||
// dLdx->addRowVector(dLdmu, dLdx);
|
|
||||||
// else
|
|
||||||
// dLdx->template applyBroadcast<simdOps::Add<T>>({1}, dLdmu, dLdx, nullptr);
|
|
||||||
|
|
||||||
// *outEpsilon = *dLdx;
|
|
||||||
|
|
||||||
// //TODO rework this to avoid the assign here
|
|
||||||
// // dGammaView->assign(dGamma);
|
|
||||||
// // dBetaView->assign(dBeta);
|
|
||||||
// // dGlobalMeanView->assign((T)0.);
|
|
||||||
// // dGlobalVarView->assign((T)0.);
|
|
||||||
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView);
|
|
||||||
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView);
|
|
||||||
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_MEAN, dGlobalMeanView);
|
|
||||||
// // retGradient.setGradientFor(BatchNormalizationParamInitializer.GLOBAL_VAR, dGlobalVarView);
|
|
||||||
|
|
||||||
// delete std;
|
|
||||||
// delete xMu;
|
|
||||||
// delete xHat;
|
|
||||||
// delete mean;
|
|
||||||
// delete var;
|
|
||||||
// delete dBeta;
|
|
||||||
// delete dGamma;
|
|
||||||
// delete dLdVar;
|
|
||||||
// delete dxmu1;
|
|
||||||
// delete dxmu2;
|
|
||||||
// delete temp1;
|
|
||||||
// delete temp2;
|
|
||||||
// delete dxhat;
|
|
||||||
// delete dGammaView;
|
|
||||||
// delete dBetaView;
|
|
||||||
|
|
||||||
// return ND4J_STATUS_OK;
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,12 +29,12 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_softmax)
|
#if NOT_EXCLUDED(OP_softmax)
|
||||||
DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(softmax, 1, 1, true, 0, 0);
|
||||||
DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(softmax_bp, 2, 1, true, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Local response normalization implementation as TF.
|
* Local response normalization implementation as TF.
|
||||||
* input: 4D array
|
* input: 4D array
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
*
|
*
|
||||||
* 0: bias
|
* 0: bias
|
||||||
|
@ -42,8 +42,8 @@ namespace nd4j {
|
||||||
* 2: beta
|
* 2: beta
|
||||||
*
|
*
|
||||||
* Int arg: depth - optional local radius
|
* Int arg: depth - optional local radius
|
||||||
*
|
*
|
||||||
* output - 4D array
|
* output - 4D array
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_lrn)
|
#if NOT_EXCLUDED(OP_lrn)
|
||||||
DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0);
|
DECLARE_CONFIGURABLE_OP(lrn, 1, 1, true, 3, 0);
|
||||||
|
@ -51,10 +51,10 @@ namespace nd4j {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Local response normalization - backprop variant.
|
* Local response normalization - backprop variant.
|
||||||
* input:
|
* input:
|
||||||
* 0 - 4D array of data
|
* 0 - 4D array of data
|
||||||
* 1 - epsilon - 4D array of approximation
|
* 1 - epsilon - 4D array of approximation
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
*
|
*
|
||||||
* 0: bias
|
* 0: bias
|
||||||
|
@ -70,34 +70,31 @@ namespace nd4j {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Batch normalization implementation.
|
* Batch normalization implementation.
|
||||||
* Reference: https://arxiv.org/abs/1502.03167v3
|
* Reference: https://arxiv.org/abs/1502.03167v3
|
||||||
*
|
*
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* input: input array (any number of dimensions)
|
* input: input array (any number of dimensions)
|
||||||
* mean:
|
* mean:
|
||||||
* variance:
|
* variance:
|
||||||
* gamma:
|
* gamma:
|
||||||
* beta:
|
* beta:
|
||||||
*
|
*
|
||||||
* Int args:
|
* Int args:
|
||||||
* 0: apply scale
|
* 0: apply scale
|
||||||
* 1: apply offset
|
* 1: apply offset
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: epsilon
|
* 0: epsilon
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_batchnorm)
|
#if NOT_EXCLUDED(OP_batchnorm)
|
||||||
DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2);
|
DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2);
|
||||||
#endif
|
#endif
|
||||||
#if NOT_EXCLUDED(OP_batchnorm_new)
|
|
||||||
DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* back prop in batch normalization
|
* back prop in batch normalization
|
||||||
*
|
*
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* input: input array (any number of dimensions)
|
* input: input array (any number of dimensions)
|
||||||
* mean:
|
* mean:
|
||||||
|
@ -105,11 +102,11 @@ namespace nd4j {
|
||||||
* gamma: optional
|
* gamma: optional
|
||||||
* beta: optional
|
* beta: optional
|
||||||
* dLdOut: next epsilon
|
* dLdOut: next epsilon
|
||||||
*
|
*
|
||||||
* Int args:
|
* Int args:
|
||||||
* 0: apply scale
|
* 0: apply scale
|
||||||
* 1: apply offset
|
* 1: apply offset
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: epsilon
|
* 0: epsilon
|
||||||
*
|
*
|
||||||
|
@ -117,8 +114,8 @@ namespace nd4j {
|
||||||
* dL/dInput
|
* dL/dInput
|
||||||
* dL/dMean
|
* dL/dMean
|
||||||
* dL/dVariance
|
* dL/dVariance
|
||||||
* dL/dGamma
|
* dL/dGamma, optional
|
||||||
* dL/dBeta
|
* dL/dBeta, optional
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_batchnorm)
|
#if NOT_EXCLUDED(OP_batchnorm)
|
||||||
DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2);
|
DECLARE_CUSTOM_OP(batchnorm_bp, 4, 3, false, 1, 2);
|
||||||
|
@ -131,30 +128,30 @@ namespace nd4j {
|
||||||
* x: parameters, any shape
|
* x: parameters, any shape
|
||||||
* y: gradients. same shape as x
|
* y: gradients. same shape as x
|
||||||
* lr: optional, learning rate
|
* lr: optional, learning rate
|
||||||
*
|
*
|
||||||
* T args:
|
* T args:
|
||||||
* 0: optional, learning rate
|
* 0: optional, learning rate
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_apply_sgd)
|
#if NOT_EXCLUDED(OP_apply_sgd)
|
||||||
DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0);
|
DECLARE_CONFIGURABLE_OP(apply_sgd, 2, 1, true, -2, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
|
* This operation performs batch normalization of layer, it is based on following article http://arxiv.org/abs/1502.03167.
|
||||||
* Expected arguments:
|
* Expected arguments:
|
||||||
* x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
|
* x: input 4D array of shape [bS,iH,iW,iD] (data format = NHWC) or [bS,iD,iH,iW] (data format = NCHW), where
|
||||||
* bS - batch size
|
* bS - batch size
|
||||||
* iH - input height
|
* iH - input height
|
||||||
* iW - input width
|
* iW - input width
|
||||||
* iD - input depth (or number of channels)
|
* iD - input depth (or number of channels)
|
||||||
* scale: 1D input array of scale factors, shape [iD]
|
* scale: 1D input array of scale factors, shape [iD]
|
||||||
* offset: 1D input array of offsets (shifts), shape [iD]
|
* offset: 1D input array of offsets (shifts), shape [iD]
|
||||||
* mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
* mean: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
||||||
* variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
* variance: 1D input array of population mean used for inference, shape [iD], this array is required only if isTraining = false
|
||||||
*
|
*
|
||||||
* T input arguments:
|
* T input arguments:
|
||||||
* 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
|
* 0: epsilon, it is optional argument, default value is 0.001, this is small number to be added to the variance of x
|
||||||
*
|
*
|
||||||
* integer input arguments:
|
* integer input arguments:
|
||||||
* 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
|
* 0: dataFormat, may have two values: zero -> NHWC, unity -> NCHW
|
||||||
* 1: isTraining, may have two values: zero -> inference, unity -> training
|
* 1: isTraining, may have two values: zero -> inference, unity -> training
|
||||||
|
|
|
@ -32,6 +32,8 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector<int>& axes, const double epsilon) {
|
static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray* variance, const NDArray* gamma, const NDArray* beta, NDArray* output, const std::vector<int>& axes, const double epsilon) {
|
||||||
|
|
||||||
|
// formula: output = gamma * ((input - mean) / sqrt(variance + epsilon)) + beta
|
||||||
|
|
||||||
NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo
|
NDArray sigmaInvGam(mean); // do not copy mean's buffer, take only its shapeInfo
|
||||||
T eps = epsilon;
|
T eps = epsilon;
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
//
|
//
|
||||||
// @author saudet
|
// @author saudet
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/PlatformHelper.h>
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
@ -28,139 +29,679 @@
|
||||||
#include <ops/declarable/helpers/convolutions.h>
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
|
||||||
using namespace mkldnn;
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace platforms {
|
namespace platforms {
|
||||||
PLATFORM_IMPL(batchnorm_new) {
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto mean = INPUT_VARIABLE(1);
|
|
||||||
auto variance = INPUT_VARIABLE(2);
|
|
||||||
NDArray *gamma = nullptr;
|
|
||||||
NDArray *beta = nullptr;
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
const bool applyScale = (bool) INT_ARG(0);
|
//////////////////////////////////////////////////////////////////////////
|
||||||
const bool applyOffset = (bool) INT_ARG(1);
|
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {
|
||||||
const double epsilon = T_ARG(0);
|
|
||||||
|
|
||||||
if (applyScale)
|
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
|
||||||
gamma = INPUT_VARIABLE(3);
|
// also it gives wrong results for formats nhwc and ndhwc
|
||||||
if (applyOffset)
|
|
||||||
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
|
||||||
|
|
||||||
std::vector<int> axes;
|
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
if (block.numI() > 2)
|
// mean -> 1D [c]
|
||||||
for (int i = 2; i < block.numI(); ++i)
|
// variance -> 1D [c]
|
||||||
axes.push_back(INT_ARG(i));
|
// weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
|
||||||
else
|
// z(output) - same shape as x
|
||||||
axes.push_back(input->rankOf() - 1);
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
const int xRank = x->rankOf();
|
||||||
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
|
||||||
weights({0, 1, 0, 0}).assign(1.0f);
|
|
||||||
weights({1, 2, 0, 0}).assign(0.0f);
|
|
||||||
|
|
||||||
mkldnn_memory_desc_t empty;
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(
|
|
||||||
empty), user_dst_md(empty);
|
|
||||||
|
|
||||||
auto norm_flag = normalization_flags::use_global_stats;
|
// input type
|
||||||
if (applyScale || applyOffset)
|
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
|
||||||
norm_flag |= normalization_flags::use_scale_shift;
|
|
||||||
|
|
||||||
mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
|
// indicate whether gamma or/and beta are given
|
||||||
&batchnorm_src_md, nullptr, &batchnorm_dst_md,
|
auto flags = mkldnn::normalization_flags::use_global_stats;
|
||||||
&user_src_md, nullptr, &user_dst_md, axes[0]);
|
if (weights != nullptr)
|
||||||
|
flags |= mkldnn::normalization_flags::use_scale_shift;
|
||||||
|
|
||||||
auto batchnorm_desc = batch_normalization_forward::desc(prop_kind::forward_inference, batchnorm_src_md, epsilon, norm_flag);
|
mkldnn::memory::dims dims;
|
||||||
|
mkldnn::memory::format_tag format;
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
if(xRank == 2) {
|
||||||
mkldnn::stream stream(engine);
|
dims = {x->sizeAt(0), x->sizeAt(1)};
|
||||||
auto batchnorm_prim_desc = batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
format = mkldnn::memory::format_tag::nc;
|
||||||
auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
|
||||||
auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
|
||||||
auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
|
|
||||||
mean->buffer());
|
|
||||||
auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
|
|
||||||
variance->buffer());
|
|
||||||
auto batchnorm_src_memory = user_src_memory;
|
|
||||||
mkldnn::memory m(batchnorm_src_md, engine);
|
|
||||||
if (m.get_desc() != user_src_memory.get_desc()) {
|
|
||||||
batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
|
|
||||||
reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
|
|
||||||
batchnorm_src_memory);
|
|
||||||
}
|
|
||||||
auto batchnorm_dst_memory = user_dst_memory;
|
|
||||||
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
|
||||||
batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
|
|
||||||
}
|
|
||||||
if (applyScale || applyOffset) {
|
|
||||||
if (gamma != nullptr) {
|
|
||||||
weights({0, 1, 0, 0}).assign(gamma);
|
|
||||||
}
|
|
||||||
if (beta != nullptr) {
|
|
||||||
weights({1, 2, 0, 0}).assign(beta);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
|
|
||||||
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
|
||||||
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
|
||||||
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
|
||||||
{MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
|
|
||||||
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
|
||||||
} else {
|
|
||||||
batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
|
||||||
{{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
|
||||||
{MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
|
||||||
{MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
|
||||||
{MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
|
||||||
}
|
|
||||||
if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
|
||||||
reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
|
|
||||||
user_dst_memory);
|
|
||||||
}
|
|
||||||
stream.wait();
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
PLATFORM_CHECK(batchnorm_new) {
|
|
||||||
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
|
||||||
if (::optimalLevel() < 2)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto mean = INPUT_VARIABLE(1);
|
|
||||||
auto variance = INPUT_VARIABLE(2);
|
|
||||||
NDArray *gamma = nullptr;
|
|
||||||
NDArray *beta = nullptr;
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
const bool applyScale = (bool) INT_ARG(0);
|
|
||||||
const bool applyOffset = (bool) INT_ARG(1);
|
|
||||||
const double epsilon = T_ARG(0);
|
|
||||||
|
|
||||||
if (applyScale)
|
|
||||||
gamma = INPUT_VARIABLE(3);
|
|
||||||
if (applyOffset)
|
|
||||||
beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
|
||||||
|
|
||||||
std::vector<int> axes;
|
|
||||||
if (block.numI() > 2)
|
|
||||||
for (int i = 2; i < block.numI(); ++i)
|
|
||||||
axes.push_back(INT_ARG(i));
|
|
||||||
else
|
|
||||||
axes.push_back(input->rankOf() - 1);
|
|
||||||
|
|
||||||
return block.isUseMKLDNN() &&
|
|
||||||
nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
|
|
||||||
axes.size() == 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
else if(xRank == 4) {
|
||||||
|
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
|
||||||
|
format = mkldnn::memory::format_tag::nchw;
|
||||||
|
}
|
||||||
|
else { // xRank = 5
|
||||||
|
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
|
||||||
|
format = mkldnn::memory::format_tag::ncdhw;
|
||||||
|
}
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// x
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||||
|
if(xRank > 2) {
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
|
||||||
|
}
|
||||||
|
if(xRank > 4)
|
||||||
|
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
|
||||||
|
|
||||||
|
// z, output
|
||||||
|
mkldnn::memory::desc z_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
mkldnn::memory::desc z_user_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
z_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
|
||||||
|
if(xRank > 2) {
|
||||||
|
z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2];
|
||||||
|
z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3];
|
||||||
|
}
|
||||||
|
if(xRank > 4)
|
||||||
|
z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4];
|
||||||
|
|
||||||
|
|
||||||
|
// batchnorm forward description
|
||||||
|
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||||
|
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory and check whether reorder is required
|
||||||
|
|
||||||
|
// x
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
|
||||||
|
const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// z
|
||||||
|
auto z_user_mem = mkldnn::memory(z_user_md, engine, z->getBuffer());
|
||||||
|
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||||
|
auto z_mkl_mem = zReorder ? mkldnn::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||||
|
if (zReorder)
|
||||||
|
mkldnn::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_DST] = z_mkl_mem;
|
||||||
|
|
||||||
|
// mean
|
||||||
|
auto mean_mkl_mem = mkldnn::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||||
|
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
|
||||||
|
|
||||||
|
// variance
|
||||||
|
auto var_mkl_mem = mkldnn::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||||
|
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
|
||||||
|
|
||||||
|
// gamma and beta (and their gradients) if they are present
|
||||||
|
if(weights != nullptr) {
|
||||||
|
|
||||||
|
auto w_mkl_mem = mkldnn::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
mkldnn::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (zReorder)
|
||||||
|
mkldnn::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
|
||||||
|
const float epsilon, NDArray* dLdI, NDArray* dLdW) {
|
||||||
|
|
||||||
|
// unfortunately mkl dnn doesn't support any format (mkldnn::memory::format_tag::any)
|
||||||
|
// also it gives wrong results for formats nhwc and ndhwc
|
||||||
|
|
||||||
|
// x -> 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
|
// mean -> 1D [c]
|
||||||
|
// variance -> 1D [c]
|
||||||
|
// dLdO - same shape as x
|
||||||
|
// weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
|
||||||
|
// dLdI - same shape as x
|
||||||
|
// dLdW - same shape as weights, dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta
|
||||||
|
|
||||||
|
const int xRank = x->rankOf();
|
||||||
|
|
||||||
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
|
// input type
|
||||||
|
mkldnn::memory::data_type type = mkldnn::memory::data_type::f32;
|
||||||
|
|
||||||
|
// indicate whether gamma or/and beta are given
|
||||||
|
auto flags = mkldnn::normalization_flags::use_global_stats;
|
||||||
|
if (weights != nullptr)
|
||||||
|
flags |= mkldnn::normalization_flags::use_scale_shift;
|
||||||
|
|
||||||
|
mkldnn::memory::dims dims;
|
||||||
|
mkldnn::memory::format_tag format;
|
||||||
|
|
||||||
|
if(xRank == 2) {
|
||||||
|
dims = {x->sizeAt(0), x->sizeAt(1)};
|
||||||
|
format = mkldnn::memory::format_tag::nc;
|
||||||
|
}
|
||||||
|
else if(xRank == 4) {
|
||||||
|
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
|
||||||
|
format = mkldnn::memory::format_tag::nchw;
|
||||||
|
}
|
||||||
|
else { // xRank = 5
|
||||||
|
dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
|
||||||
|
format = mkldnn::memory::format_tag::ncdhw;
|
||||||
|
}
|
||||||
|
|
||||||
|
// memory descriptors for arrays
|
||||||
|
|
||||||
|
// x
|
||||||
|
mkldnn::memory::desc x_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
mkldnn::memory::desc x_user_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||||
|
if(xRank > 2) {
|
||||||
|
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
||||||
|
x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
|
||||||
|
}
|
||||||
|
if(xRank > 4)
|
||||||
|
x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
|
||||||
|
|
||||||
|
// dLdO
|
||||||
|
mkldnn::memory::desc dLdO_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
mkldnn::memory::desc dLdO_user_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
dLdO_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
|
||||||
|
dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
|
||||||
|
if(xRank > 2) {
|
||||||
|
dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2];
|
||||||
|
dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3];
|
||||||
|
}
|
||||||
|
if(xRank > 4)
|
||||||
|
dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];
|
||||||
|
|
||||||
|
// dLdI
|
||||||
|
mkldnn::memory::desc dLdI_mkl_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
mkldnn::memory::desc dLdI_user_md = mkldnn::memory::desc(dims, type, format);
|
||||||
|
dLdI_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||||
|
dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
|
||||||
|
dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
|
||||||
|
if(xRank > 2) {
|
||||||
|
dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2];
|
||||||
|
dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3];
|
||||||
|
}
|
||||||
|
if(xRank > 4)
|
||||||
|
dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];
|
||||||
|
|
||||||
|
// batchnorm forward description
|
||||||
|
mkldnn::batch_normalization_forward::desc op_ff_desc(mkldnn::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
|
||||||
|
mkldnn::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
|
||||||
|
|
||||||
|
// batchnorm backprop description
|
||||||
|
mkldnn::batch_normalization_backward::desc op_bp_desc(mkldnn::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
|
||||||
|
mkldnn::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, mkldnn::memory> args;
|
||||||
|
|
||||||
|
mkldnn::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory and check whether reorder is required
|
||||||
|
|
||||||
|
// x
|
||||||
|
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
|
||||||
|
const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? mkldnn::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
mkldnn::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// dLdO
|
||||||
|
auto dLdO_user_mem = mkldnn::memory(dLdO_user_md, engine, dLdO->getBuffer());
|
||||||
|
const bool dLdOReorder = op_bp_prim_desc.diff_src_desc() != dLdO_user_mem.get_desc();
|
||||||
|
auto dLdO_mkl_mem = dLdOReorder ? mkldnn::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdO_user_mem;
|
||||||
|
if (dLdOReorder)
|
||||||
|
mkldnn::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
|
||||||
|
args[MKLDNN_ARG_DIFF_DST] = dLdO_mkl_mem;
|
||||||
|
|
||||||
|
// mean
|
||||||
|
auto mean_mkl_mem = mkldnn::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||||
|
args[MKLDNN_ARG_MEAN] = mean_mkl_mem;
|
||||||
|
|
||||||
|
// variance
|
||||||
|
auto var_mkl_mem = mkldnn::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
|
||||||
|
args[MKLDNN_ARG_VARIANCE] = var_mkl_mem;
|
||||||
|
|
||||||
|
// dLdI
|
||||||
|
auto dLdI_user_mem = mkldnn::memory(dLdI_user_md, engine, dLdI->getBuffer());
|
||||||
|
const bool dLdIReorder = op_bp_prim_desc.diff_dst_desc() != dLdI_user_mem.get_desc();
|
||||||
|
auto dLdI_mkl_mem = dLdIReorder ? mkldnn::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdI_user_mem;
|
||||||
|
args[MKLDNN_ARG_DIFF_SRC] = dLdI_mkl_mem;
|
||||||
|
|
||||||
|
// gamma and beta (and their gradients) if they are present
|
||||||
|
if(weights != nullptr) {
|
||||||
|
|
||||||
|
auto w_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
|
||||||
|
args[MKLDNN_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
auto dLdW_mkl_mem = mkldnn::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
|
||||||
|
args[MKLDNN_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run calculations
|
||||||
|
mkldnn::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder outputs if necessary
|
||||||
|
if (dLdIReorder)
|
||||||
|
mkldnn::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(dLdI_mkl_mem.map_data<float>(),8);
|
||||||
|
}
|
||||||
|
|
||||||
|
PLATFORM_IMPL(batchnorm) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
|
auto mean = INPUT_VARIABLE(1); // [c]
|
||||||
|
auto variance = INPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* gamma = nullptr; // [c]
|
||||||
|
NDArray* beta = nullptr; // [c]
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0); // same shape as input
|
||||||
|
|
||||||
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
const double epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
if(applyScale)
|
||||||
|
gamma = INPUT_VARIABLE(3);
|
||||||
|
if(applyOffset)
|
||||||
|
beta = INPUT_VARIABLE(3 + (int)applyScale);
|
||||||
|
|
||||||
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
|
const int inRank = input->rankOf();
|
||||||
|
|
||||||
|
// get axes args to normalize input array over
|
||||||
|
std::vector<int> axes;
|
||||||
|
if(numOfIntArgs > 2)
|
||||||
|
for(int i = 2; i < numOfIntArgs; ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
|
const int numOfAxes = axes.size();
|
||||||
|
REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
|
||||||
|
REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
|
||||||
|
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
|
||||||
|
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
|
||||||
|
if(gamma != nullptr)
|
||||||
|
REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
|
||||||
|
if(beta != nullptr)
|
||||||
|
REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
|
||||||
|
|
||||||
|
// types of all input arrays should be the same (except dLdO)
|
||||||
|
for(int i = 1; i < block.width() - 1; ++i)
|
||||||
|
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same !");
|
||||||
|
|
||||||
|
|
||||||
|
NDArray *weights = nullptr;
|
||||||
|
|
||||||
|
if(applyScale || applyOffset) {
|
||||||
|
|
||||||
|
weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
|
||||||
|
|
||||||
|
if(applyScale)
|
||||||
|
(*weights)({0,1, 0,0}).assign(gamma);
|
||||||
|
else
|
||||||
|
(*weights)({0,1, 0,0}).assign(1);
|
||||||
|
if(applyOffset)
|
||||||
|
(*weights)({1,2, 0,0}).assign(beta);
|
||||||
|
else
|
||||||
|
(*weights)({1,2, 0,0}).assign(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
batchnormMKLDNN(input, mean, variance, weights, epsilon, output);
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_CHECK(batchnorm) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
|
auto mean = INPUT_VARIABLE(1); // [c]
|
||||||
|
auto variance = INPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* gamma = nullptr; // [c]
|
||||||
|
NDArray* beta = nullptr; // [c]
|
||||||
|
|
||||||
|
auto output = OUTPUT_VARIABLE(0); // same shape as input
|
||||||
|
|
||||||
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
|
||||||
|
if(applyScale)
|
||||||
|
gamma = INPUT_VARIABLE(3);
|
||||||
|
if(applyOffset)
|
||||||
|
beta = INPUT_VARIABLE(3 + (int)applyScale);
|
||||||
|
|
||||||
|
|
||||||
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
|
std::vector<int> axes;
|
||||||
|
if(numOfIntArgs > 2)
|
||||||
|
for(int i = 2; i < numOfIntArgs; ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
|
DataType inputType = input->dataType();
|
||||||
|
DataType meanType = mean->dataType();
|
||||||
|
DataType varType = variance->dataType();
|
||||||
|
DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
|
||||||
|
DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32;
|
||||||
|
DataType outType = output->dataType();
|
||||||
|
|
||||||
|
const int inRank = input->rankOf();
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
|
||||||
|
(inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
|
||||||
|
gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// PLATFORM_IMPL(batchnorm) {
|
||||||
|
|
||||||
|
// auto input = INPUT_VARIABLE(0);
|
||||||
|
// auto mean = INPUT_VARIABLE(1);
|
||||||
|
// auto variance = INPUT_VARIABLE(2);
|
||||||
|
// NDArray *gamma = nullptr;
|
||||||
|
// NDArray *beta = nullptr;
|
||||||
|
|
||||||
|
// auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
// const bool applyScale = (bool) INT_ARG(0);
|
||||||
|
// const bool applyOffset = (bool) INT_ARG(1);
|
||||||
|
// const double epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
// if (applyScale)
|
||||||
|
// gamma = INPUT_VARIABLE(3);
|
||||||
|
// if (applyOffset)
|
||||||
|
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
||||||
|
|
||||||
|
// std::vector<int> axes;
|
||||||
|
// if (block.numI() > 2)
|
||||||
|
// for (int i = 2; i < block.numI(); ++i)
|
||||||
|
// axes.push_back(INT_ARG(i));
|
||||||
|
// else
|
||||||
|
// axes.push_back(input->rankOf() - 1);
|
||||||
|
|
||||||
|
// std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
||||||
|
// NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
||||||
|
// weights({0, 1, 0, 0}).assign(1.0f);
|
||||||
|
// weights({1, 2, 0, 0}).assign(0.0f);
|
||||||
|
|
||||||
|
// mkldnn_memory_desc_t empty;
|
||||||
|
// mkldnn::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
|
||||||
|
|
||||||
|
// auto flag = mkldnn::normalization_flags::use_global_stats;
|
||||||
|
// if (applyScale || applyOffset)
|
||||||
|
// flag |= mkldnn::normalization_flags::use_scale_shift;
|
||||||
|
|
||||||
|
// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
|
||||||
|
// &batchnorm_src_md, nullptr, &batchnorm_dst_md,
|
||||||
|
// &user_src_md, nullptr, &user_dst_md, axes[0]);
|
||||||
|
|
||||||
|
// auto batchnorm_desc = mkldnn::batch_normalization_forward::desc(mkldnn::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
|
||||||
|
|
||||||
|
// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
// mkldnn::stream stream(engine);
|
||||||
|
// auto batchnorm_prim_desc = mkldnn::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
|
||||||
|
// auto user_src_memory = mkldnn::memory(user_src_md, engine, input->buffer());
|
||||||
|
// auto user_dst_memory = mkldnn::memory(user_dst_md, engine, output->buffer());
|
||||||
|
// auto batchnorm_mean_memory = mkldnn::memory(batchnorm_prim_desc.mean_desc(), engine,
|
||||||
|
// mean->buffer());
|
||||||
|
// auto batchnorm_variance_memory = mkldnn::memory(batchnorm_prim_desc.variance_desc(), engine,
|
||||||
|
// variance->buffer());
|
||||||
|
// auto batchnorm_src_memory = user_src_memory;
|
||||||
|
// mkldnn::memory m(batchnorm_src_md, engine);
|
||||||
|
// if (m.get_desc() != user_src_memory.get_desc()) {
|
||||||
|
// batchnorm_src_memory = mkldnn::memory(batchnorm_src_md, engine);
|
||||||
|
// mkldnn::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
|
||||||
|
// batchnorm_src_memory);
|
||||||
|
// }
|
||||||
|
// auto batchnorm_dst_memory = user_dst_memory;
|
||||||
|
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
// batchnorm_dst_memory = mkldnn::memory(batchnorm_prim_desc.dst_desc(), engine);
|
||||||
|
// }
|
||||||
|
// if (applyScale || applyOffset) {
|
||||||
|
// if (gamma != nullptr) {
|
||||||
|
// weights({0, 1, 0, 0}).assign(gamma);
|
||||||
|
// }
|
||||||
|
// if (beta != nullptr) {
|
||||||
|
// weights({1, 2, 0, 0}).assign(beta);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// auto batchnorm_weights_memory = mkldnn::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
|
||||||
|
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||||
|
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
||||||
|
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
||||||
|
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
||||||
|
// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
|
||||||
|
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
||||||
|
// } else {
|
||||||
|
// mkldnn::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
|
||||||
|
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
|
||||||
|
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
|
||||||
|
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
|
||||||
|
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
|
||||||
|
// }
|
||||||
|
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
||||||
|
// mkldnn::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
|
||||||
|
// user_dst_memory);
|
||||||
|
// }
|
||||||
|
// stream.wait();
|
||||||
|
|
||||||
|
// return Status::OK();
|
||||||
|
// }
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// PLATFORM_CHECK(batchnorm) {
|
||||||
|
// // we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
// auto input = INPUT_VARIABLE(0);
|
||||||
|
// auto mean = INPUT_VARIABLE(1);
|
||||||
|
// auto variance = INPUT_VARIABLE(2);
|
||||||
|
// NDArray *gamma = nullptr;
|
||||||
|
// NDArray *beta = nullptr;
|
||||||
|
|
||||||
|
// auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
// const bool applyScale = (bool) INT_ARG(0);
|
||||||
|
// const bool applyOffset = (bool) INT_ARG(1);
|
||||||
|
// const double epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
// if (applyScale)
|
||||||
|
// gamma = INPUT_VARIABLE(3);
|
||||||
|
// if (applyOffset)
|
||||||
|
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
|
||||||
|
|
||||||
|
// std::vector<int> axes;
|
||||||
|
// if (block.numI() > 2)
|
||||||
|
// for (int i = 2; i < block.numI(); ++i)
|
||||||
|
// axes.push_back(INT_ARG(i));
|
||||||
|
// else
|
||||||
|
// axes.push_back(input->rankOf() - 1);
|
||||||
|
|
||||||
|
// return block.isUseMKLDNN() &&
|
||||||
|
// nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
|
||||||
|
// axes.size() == 1;
|
||||||
|
// }
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_IMPL(batchnorm_bp) {
|
||||||
|
|
||||||
|
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
|
NDArray* mean = INPUT_VARIABLE(1); // [c]
|
||||||
|
NDArray* variance = INPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* dLdO = INPUT_VARIABLE(3); // same as input
|
||||||
|
NDArray* gamma = nullptr; // [c]
|
||||||
|
NDArray* beta = nullptr; // [c]
|
||||||
|
|
||||||
|
NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input
|
||||||
|
NDArray* dLdM = OUTPUT_VARIABLE(1); // [c]
|
||||||
|
NDArray* dLdV = OUTPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* dLdG = nullptr; // [c]
|
||||||
|
NDArray* dLdB = nullptr; // [c]
|
||||||
|
|
||||||
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
const float epsilon = T_ARG(0);
|
||||||
|
|
||||||
|
if(applyScale) {
|
||||||
|
gamma = INPUT_VARIABLE(4);
|
||||||
|
dLdG = OUTPUT_VARIABLE(3);
|
||||||
|
}
|
||||||
|
if(applyOffset) {
|
||||||
|
beta = INPUT_VARIABLE(4 + (int)applyScale);
|
||||||
|
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
|
const int inRank = input->rankOf();
|
||||||
|
|
||||||
|
// get axes args to normalize input array over
|
||||||
|
std::vector<int> axes;
|
||||||
|
if(numOfIntArgs > 2)
|
||||||
|
for(int i = 2; i < numOfIntArgs; ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(inRank-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
|
const int numOfAxes = axes.size();
|
||||||
|
REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
|
||||||
|
REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
|
||||||
|
REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
|
||||||
|
REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
|
||||||
|
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
|
||||||
|
if(gamma != nullptr)
|
||||||
|
REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
|
||||||
|
if(beta != nullptr)
|
||||||
|
REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead !", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());
|
||||||
|
|
||||||
|
// types of all input arrays should be the same (except dLdO)
|
||||||
|
for(int i = 1; i < block.width() - 1; ++i)
|
||||||
|
REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same !");
|
||||||
|
|
||||||
|
|
||||||
|
NDArray *weights = nullptr, *dLdW = nullptr;
|
||||||
|
|
||||||
|
if(applyScale || applyOffset) {
|
||||||
|
weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
|
||||||
|
dLdW = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
|
||||||
|
if(applyScale)
|
||||||
|
(*weights)({0,1, 0,0}).assign(gamma);
|
||||||
|
else
|
||||||
|
(*weights)({0,1, 0,0}).assign(1);
|
||||||
|
if(applyOffset)
|
||||||
|
(*weights)({1,2, 0,0}).assign(beta);
|
||||||
|
else
|
||||||
|
(*weights)({1,2, 0,0}).assign(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
*dLdM = 0;
|
||||||
|
*dLdV = 0;
|
||||||
|
|
||||||
|
batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
|
||||||
|
|
||||||
|
if(applyScale || applyOffset) {
|
||||||
|
if(applyScale)
|
||||||
|
dLdG->assign((*dLdW)({0,1, 0,0}));
|
||||||
|
if(applyOffset)
|
||||||
|
dLdB->assign((*dLdW)({1,2, 0,0}));
|
||||||
|
|
||||||
|
delete weights;
|
||||||
|
delete dLdW;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
PLATFORM_CHECK(batchnorm_bp) {
|
||||||
|
// we don't want to use mkldnn if cpu doesn't support avx/avx2
|
||||||
|
// if (::optimalLevel() < 2)
|
||||||
|
// return false;
|
||||||
|
|
||||||
|
NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
|
||||||
|
NDArray* mean = INPUT_VARIABLE(1); // [c]
|
||||||
|
NDArray* variance = INPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* dLdO = INPUT_VARIABLE(3); // same as input
|
||||||
|
NDArray* gamma = nullptr; // [c]
|
||||||
|
NDArray* beta = nullptr; // [c]
|
||||||
|
|
||||||
|
NDArray* dLdI = OUTPUT_VARIABLE(0); // same as input
|
||||||
|
NDArray* dLdM = OUTPUT_VARIABLE(1); // [c]
|
||||||
|
NDArray* dLdV = OUTPUT_VARIABLE(2); // [c]
|
||||||
|
NDArray* dLdG = nullptr; // [c]
|
||||||
|
NDArray* dLdB = nullptr; // [c]
|
||||||
|
|
||||||
|
const bool applyScale = (bool)INT_ARG(0);
|
||||||
|
const bool applyOffset = (bool)INT_ARG(1);
|
||||||
|
|
||||||
|
if(applyScale) {
|
||||||
|
gamma = INPUT_VARIABLE(4);
|
||||||
|
dLdG = OUTPUT_VARIABLE(3);
|
||||||
|
}
|
||||||
|
if(applyOffset) {
|
||||||
|
beta = INPUT_VARIABLE(4 + (int)applyScale);
|
||||||
|
dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int numOfIntArgs = block.getIArguments()->size();
|
||||||
|
std::vector<int> axes;
|
||||||
|
if(numOfIntArgs > 2)
|
||||||
|
for(int i = 2; i < numOfIntArgs; ++i)
|
||||||
|
axes.push_back(INT_ARG(i));
|
||||||
|
else
|
||||||
|
axes.push_back(input->rankOf()-1); // default dimension to reduce along is last dimension
|
||||||
|
|
||||||
|
DataType inputType = input->dataType();
|
||||||
|
DataType meanType = mean->dataType();
|
||||||
|
DataType varType = variance->dataType();
|
||||||
|
DataType dLdOType = dLdO->dataType();
|
||||||
|
DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
|
||||||
|
DataType betaType = beta != nullptr ? beta->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
|
DataType dLdIType = dLdI->dataType();
|
||||||
|
DataType dLdGType = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32;
|
||||||
|
DataType dLdBType = beta != nullptr ? dLdB->dataType() : DataType::FLOAT32;
|
||||||
|
|
||||||
|
const int inRank = input->rankOf();
|
||||||
|
|
||||||
|
return block.isUseMKLDNN() && axes.size() == 1 && axes[0] == 1 && (inRank == 2 || inRank == 4 || inRank == 5) &&
|
||||||
|
(inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
|
||||||
|
dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 &&
|
||||||
|
dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -132,8 +132,6 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
||||||
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||||
|
|
||||||
mkldnn_memory_desc_t empty;
|
|
||||||
|
|
||||||
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
|
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
|
||||||
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
|
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
|
||||||
|
|
||||||
|
|
|
@ -305,50 +305,50 @@ namespace nd4j {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
// mkldnn::memory::desc* batchnorm_src_md, mkldnn::memory::desc* batchnorm_diff_src_md, mkldnn::memory::desc* batchnorm_dst_md,
|
||||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
// mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, int axis) {
|
||||||
const Nd4jLong* shape = src->getShapeInfo();
|
// const Nd4jLong* shape = src->getShapeInfo();
|
||||||
Nd4jLong rank = shape[0];
|
// Nd4jLong rank = shape[0];
|
||||||
Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
|
||||||
Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
// Nd4jLong dim2 = axis >= 2 ? 1 : 2;
|
||||||
Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
// Nd4jLong dim3 = axis >= 3 ? 2 : 3;
|
||||||
mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
// mkldnn::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
|
||||||
|
|
||||||
auto type = mkldnn::memory::data_type::f32;
|
// auto type = mkldnn::memory::data_type::f32;
|
||||||
auto format = mkldnn::memory::format_tag::nchw;
|
// auto format = mkldnn::memory::format_tag::nchw;
|
||||||
auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
// auto supposed_to_be_any_format = mkldnn::memory::format_tag::nChw8c; // doesn't work with "any"
|
||||||
|
|
||||||
if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
|
||||||
*batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
// *batchnorm_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
*user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
// *user_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
user_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
// user_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
|
||||||
user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
|
||||||
user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
|
||||||
user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
|
||||||
*batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
// *batchnorm_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
*user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
// *user_diff_src_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
// user_diff_src_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
|
||||||
user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
|
||||||
user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
|
||||||
user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
|
||||||
*batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
// *batchnorm_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
|
||||||
*user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
// *user_dst_md = mkldnn::memory::desc({ batchnorm_src_tz }, type, format);
|
||||||
user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
|
// user_dst_md->data.format_kind = mkldnn_blocked; // overrides format
|
||||||
user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
|
||||||
user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
|
||||||
user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
|
||||||
user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
|
|
||||||
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||||
|
|
|
@ -62,7 +62,9 @@ namespace nd4j{
|
||||||
|
|
||||||
DECLARE_PLATFORM(lrn);
|
DECLARE_PLATFORM(lrn);
|
||||||
|
|
||||||
DECLARE_PLATFORM(batchnorm_new);
|
DECLARE_PLATFORM(batchnorm);
|
||||||
|
|
||||||
|
DECLARE_PLATFORM(batchnorm_bp);
|
||||||
|
|
||||||
DECLARE_PLATFORM(lstmLayer);
|
DECLARE_PLATFORM(lstmLayer);
|
||||||
}
|
}
|
||||||
|
|
|
@ -413,7 +413,7 @@ namespace nd4j {
|
||||||
return ctx;
|
return ctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
nd4j::ops::batchnorm_new batchnorm;
|
nd4j::ops::batchnorm batchnorm;
|
||||||
DeclarableBenchmark benchmark(batchnorm, "batchnorm");
|
DeclarableBenchmark benchmark(batchnorm, "batchnorm");
|
||||||
output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization");
|
output += helper.runOperationSuit(&benchmark, generator, batch, "Batch Normalization");
|
||||||
|
|
||||||
|
@ -1822,7 +1822,7 @@ namespace nd4j {
|
||||||
std::string result;
|
std::string result;
|
||||||
|
|
||||||
long start = nowMs();
|
long start = nowMs();
|
||||||
|
|
||||||
// set 1
|
// set 1
|
||||||
nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", "");
|
nd4j_printf("Running FullBenchmarkSuite.fastScalarBenchmark\n", "");
|
||||||
result += fastScalarBenchmark();
|
result += fastScalarBenchmark();
|
||||||
|
|
|
@ -2385,129 +2385,6 @@ TEST_F(DeclarableOpsTests1, CompactLaunchTests2) {
|
||||||
ASSERT_TRUE(exp.equalsTo(&z));
|
ASSERT_TRUE(exp.equalsTo(&z));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, batchnorm_test1) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
auto mean = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
auto beta = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011,10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369});
|
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
|
||||||
mean.assign(1.);
|
|
||||||
variance.assign(0.5);
|
|
||||||
gamma.assign(1.2);
|
|
||||||
beta.assign(1.);
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
||||||
|
|
||||||
auto output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests1, batchnorm_test2) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2,3,1,3,1});
|
|
||||||
auto mean = NDArrayFactory::create<double>('c', {1,3,2,1,2});
|
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,1,2,3,2});
|
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {2,3,2,3,1});
|
|
||||||
auto beta = NDArrayFactory::create<double>('c', {1,3,2,1,2});
|
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.52733537,-0.52733537,-0.35763144,-0.35763144,-0.18792751,-0.18792751, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, -0.01822358,-0.01822358, 0.15148035, 0.15148035, 0.32118428, 0.32118428, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 0.49088821, 0.49088821, 0.66059214, 0.66059214, 0.83029607, 0.83029607, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1. , 1. , 1.16970393, 1.16970393, 1.33940786, 1.33940786, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 1.50911179, 1.50911179, 1.67881572, 1.67881572, 1.84851965, 1.84851965, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144, 2.01822358, 2.01822358, 2.18792751, 2.18792751, 2.35763144, 2.35763144});
|
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
|
||||||
mean.assign(1.);
|
|
||||||
variance.assign(0.5);
|
|
||||||
gamma.assign(1.2);
|
|
||||||
beta.assign(1.);
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
||||||
|
|
||||||
auto output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, batchnorm_test3) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
|
||||||
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
|
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1});
|
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {1,1});
|
|
||||||
auto beta = NDArrayFactory::create<double>('c', {1,2});
|
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1., 1.16970393, 1.33940786, 1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503, 3.54555896, 3.71526289, 3.88496682, 4.05467075, 4.22437468, 4.39407861, 4.56378254, 4.73348647, 4.9031904 , 5.07289433, 5.24259826, 5.41230219, 5.58200612, 5.75171005, 5.92141398, 6.09111791, 6.26082184, 6.43052577, 6.6002297 , 6.76993364, 6.93963757, 7.1093415 , 7.27904543, 7.44874936, 7.61845329, 7.78815722, 7.95786115, 8.12756508, 8.29726901, 8.46697294, 8.63667687, 8.8063808 , 8.97608473, 9.14578866, 9.31549259, 9.48519652, 9.65490045, 9.82460438, 9.99430831,10.16401224,10.33371617,10.50342011, 10.67312404,10.84282797,11.0125319 ,11.18223583,11.35193976,11.52164369});
|
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
|
||||||
mean.assign(1.);
|
|
||||||
variance.assign(0.5);
|
|
||||||
gamma.assign(1.2);
|
|
||||||
beta.assign(1.);
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
||||||
|
|
||||||
auto output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, batchnorm_test4) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {3,2});
|
|
||||||
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
|
|
||||||
auto variance= NDArrayFactory::create<double>('c', {2,3,1,3,2});
|
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {1,1});
|
|
||||||
auto beta = NDArrayFactory::create<double>('c', {1,2});
|
|
||||||
|
|
||||||
auto expected= NDArrayFactory::create<double>('c', {2,3,2,3,2}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, -0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428});
|
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
|
||||||
mean.assign(1.);
|
|
||||||
variance.assign(0.5);
|
|
||||||
gamma.assign(1.2);
|
|
||||||
beta.assign(1.);
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
||||||
|
|
||||||
auto output = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
// TEST_F(DeclarableOpsTests1, sru_old_test1) {
|
// TEST_F(DeclarableOpsTests1, sru_old_test1) {
|
||||||
|
|
|
@ -2313,7 +2313,35 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto output = results->at(0);
|
||||||
|
// output->printBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
||||||
auto mean = NDArrayFactory::create<TypeParam>('c', {4});
|
auto mean = NDArrayFactory::create<TypeParam>('c', {4});
|
||||||
|
@ -2330,7 +2358,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
|
||||||
gamma.assign(1.2);
|
gamma.assign(1.2);
|
||||||
beta.assign(1.);
|
beta.assign(1.);
|
||||||
|
|
||||||
nd4j::ops::batchnorm_new op;
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
||||||
|
|
||||||
|
@ -2346,7 +2374,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
|
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
||||||
auto mean = NDArrayFactory::create<TypeParam>('c', {3}, {1.05, 1.1, 1.15});
|
auto mean = NDArrayFactory::create<TypeParam>('c', {3}, {1.05, 1.1, 1.15});
|
||||||
|
@ -2359,7 +2387,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm_new op;
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
|
||||||
|
|
||||||
|
@ -2374,7 +2402,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
|
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
|
||||||
auto mean = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
|
auto mean = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
|
||||||
|
@ -2387,7 +2415,7 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm_new op;
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2});
|
||||||
|
|
||||||
|
@ -2401,6 +2429,63 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_new_test3) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20.,
|
||||||
|
-19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438,
|
||||||
|
-12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32);
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto output = results->at(0);
|
||||||
|
// output->printBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 ,
|
||||||
|
20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 ,
|
||||||
|
-12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32);
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(output));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {
|
TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {
|
||||||
|
|
||||||
|
|
|
@ -2883,78 +2883,336 @@ TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test1) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {3,2});
|
NDArray input ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||||
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
|
NDArray mean ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,2});
|
NDArray variance('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {1,1});
|
NDArray gamma ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
auto beta = NDArrayFactory::create<double>('c', {1,2});
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
NDArray gradO ('c', {2,3,4}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.509112, -0.254556, 0., 0.254556,0.509112, 0.763668, 1.018224, 1.272779,
|
||||||
|
1.527335, 1.781891, 2.036447, 2.291003,2.545559, 2.800115, 3.054671, 3.309227,3.563783, 3.818338, 4.072894, 4.32745}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {6.448749, 7.212417, 8.230641, 9.50342 }, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
mean.assign(1.);
|
mean.assign(1.);
|
||||||
variance.assign(0.5);
|
variance.assign(0.5);
|
||||||
gamma.assign(1.2);
|
gamma.assign(1.2);
|
||||||
beta.assign(1.);
|
// beta.assign(1.); // has no effect on gradient calculations
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1});
|
nd4j::ops::batchnorm_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &beta, &dLdO}, {1e-5}, {1,1});
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm opFF;
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
|
||||||
nd4j::ops::batchnorm_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test2) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
auto mean = NDArrayFactory::create<double>('c', {2,3,2});
|
NDArray mean ('c', {3}, {1.05, 1.1, 1.15});
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,3,1,3,1});
|
NDArray variance('c', {3}, {0.5, 0.6, 0.7});
|
||||||
auto gamma = NDArrayFactory::create<double>('c', {1,1});
|
NDArray gamma ('c', {3}, {1.2, 1.3, 1.4});
|
||||||
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3,2});
|
NDArray beta ('c', {3}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.272779, -1.018224, -0.763668,-0.503484, -0.251742, 0., 0.251742,0.501992, 0.752989, 1.003985, 1.254981,
|
||||||
|
1.527335, 1.781891, 2.036447, 2.291003,2.517418, 2.76916 , 3.020902, 3.272644,3.513947, 3.764943, 4.015939, 4.266936});
|
||||||
|
NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388});
|
||||||
|
NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4});
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
mean.assign(1.);
|
// beta.assign(1.); // has no effect on gradient calculations
|
||||||
variance.assign(0.5);
|
gradO.linspace(-0.9, 0.15);
|
||||||
gamma.assign(1.2);
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &mean, &variance, &gamma}, {1e-5}, {1,0});
|
nd4j::ops::batchnorm_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &gamma, &dLdO}, {1e-5}, {1,0});
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm opFF;
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
|
||||||
nd4j::ops::batchnorm_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2,3,1,3});
|
NDArray input ('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
auto mean = NDArrayFactory::create<double>('c', {1,3,2,1});
|
NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4});
|
||||||
auto variance = NDArrayFactory::create<double>('c', {2,1,2,3});
|
NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2});
|
||||||
auto dLdO = NDArrayFactory::create<double>('c', {2,3,2,3});
|
NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9});
|
||||||
|
NDArray beta ('c', {2,1,4}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray gradO ('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,3,4}, {-1.527335, -1.258709, -1.003985, -0.754668,-0.509112, -0.251742, 0., 0.251556,0.509112, 0.755225, 1.003985, 1.25778 ,
|
||||||
|
1.517885, 1.784991, 2.05947 , 2.341504,2.529808, 2.804986, 3.089205, 3.382173,3.541731, 3.824981, 4.11894 , 4.422841});
|
||||||
|
NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 });
|
||||||
|
NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. , 0.45, 4.5 , 4.95, 5.4 , 5.85});
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
mean.assign(1.);
|
// beta.assign(1.); // has no effect on gradient calculations
|
||||||
variance.assign(0.5);
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&input, &mean, &variance}, {1e-5}, {0,0});
|
nd4j::ops::batchnorm_bp op;
|
||||||
const OpArgsHolder argsHolderBP({&input, &mean, &variance, &dLdO}, {1e-5}, {0,0});
|
|
||||||
|
|
||||||
nd4j::ops::batchnorm opFF;
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,0,2});
|
||||||
nd4j::ops::batchnorm_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test4) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {1.442483, 0.9502 , 0.569207, 0.314641}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm_bp op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test5) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,4,2,2}, {1.527335, 1.272779,1.018224, 0.763668,-0.466136, -0.233068,0., 0.233068,-0.442716, -0.664075,-0.885433, -1.106791,1.287169, 1.501697,1.716225, 1.930753,
|
||||||
|
-2.545559, -2.800115,-3.054671, -3.309227,3.262951, 3.496019,3.729087, 3.962155,-3.984448, -4.205806,-4.427164, -4.648522,4.719618, 4.934146,5.148675, 5.363203}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {4.2, 9. , 13.8, 18.6}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm_bp op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test6) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584, 0.509112, -0.233068, -0., 0.214528, -0.509112, 0.699204, -0.885433, 1.072641, -1.527335, 1.631475, -1.770866, 1.930753,
|
||||||
|
-2.545559, 2.563747, -2.656298, 2.788865, -3.563783, 3.496019, -3.541731, 3.646978, -4.582006, 4.42829 , -4.427164, 4.50509 , -5.60023 , 5.360562, -5.312597, 5.363203}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {9.6, 10.8, 12. , 13.2}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm_bp op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,3});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test7) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO ('c', {2,2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,2,2,2,4}, {1.527335, -1.16534 , 0.885433, -0.643584,0.509112, -0.233068, -0., 0.214528,-0.509112, 0.699204, -0.885433, 1.072641,-1.527335, 1.631475, -1.770866,
|
||||||
|
1.930753,-2.545559, 2.563747, -2.656298, 2.788865,-3.563783, 3.496019, -3.541731, 3.646978,-4.582006, 4.42829 , -4.427164,
|
||||||
|
4.50509 ,-5.60023 , 5.360562, -5.312597, 5.363203, -6.618453, 6.292834, -6.19803 , 6.221315,-7.636677, 7.225105, -7.083463,
|
||||||
|
7.079428,-8.6549 , 8.157377, -7.968895, 7.93754 ,-9.673124, 9.089649, -8.854328, 8.795652, -10.691348, 10.02192 , -9.739761,
|
||||||
|
9.653765,-11.709571, 10.954192, -10.625194, 10.511877,-12.727795, 11.886464, -11.510627, 11.36999 ,-13.746018, 12.818735, -12.39606 , 12.228102}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm_bp op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,4});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
// dLdI->printBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, batchnorm_bp_test8) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray beta ('c', {4}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray gradO ('c', {2,4,2,2,2}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expdLdI('c', {2,4,2,2,2}, {1.527335, 1.272779, 1.018224, 0.763668, 0.509112, 0.254556, -0. , -0.254556, 0.466136, 0.699204, 0.932272, 1.16534 , 1.398407, 1.631475, 1.864543, 2.097611,
|
||||||
|
-2.213582, -2.43494 , -2.656298, -2.877657, -3.099015, -3.320373, -3.541731, -3.76309 , 3.861506, 4.076034, 4.290562, 4.50509 , 4.719618, 4.934146, 5.148675, 5.363203,
|
||||||
|
-6.618453, -6.873009, -7.127565, -7.382121, -7.636677, -7.891233, -8.145789, -8.400345, 7.924309, 8.157377, 8.390445, 8.623513, 8.856581, 9.089649, 9.322717, 9.555784,
|
||||||
|
-9.297045, -9.518403, -9.739761, -9.961119, -10.182477, -10.403836, -10.625194, -10.846552, 10.726405, 10.940933, 11.155462, 11.36999 , 11.584518, 11.799046, 12.013574, 12.228102}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(0.1, 0.1);
|
||||||
|
gradO.linspace(-0.9, 0.15);
|
||||||
|
|
||||||
|
nd4j::ops::batchnorm_bp op;
|
||||||
|
|
||||||
|
auto results = op.execute({&input, &mean, &variance, &gradO, &gamma, &beta}, {1e-5}, {1,1,1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto dLdI = results->at(0);
|
||||||
|
auto dLdG = results->at(3);
|
||||||
|
auto dLdB = results->at(4);
|
||||||
|
|
||||||
|
// dLdI->printBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdI.isSameShapeStrict(dLdI));
|
||||||
|
ASSERT_TRUE(expdLdI.equalsTo(dLdI));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdG.isSameShapeStrict(dLdG));
|
||||||
|
ASSERT_TRUE(expdLdG.equalsTo(dLdG));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expdLdB.isSameShapeStrict(dLdB));
|
||||||
|
ASSERT_TRUE(expdLdB.equalsTo(dLdB));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
/*
|
/*
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ TEST_F(MklDnnTests, helpers_includer) {
|
||||||
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp;
|
nd4j::ops::platforms::PLATFORM_maxpool3dnew_bp maxpool3d_bp;
|
||||||
|
|
||||||
nd4j::ops::platforms::PLATFORM_lrn lrn;
|
nd4j::ops::platforms::PLATFORM_lrn lrn;
|
||||||
nd4j::ops::platforms::PLATFORM_batchnorm_new batchnorm;
|
nd4j::ops::platforms::PLATFORM_batchnorm batchnorm;
|
||||||
|
|
||||||
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
|
printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm});
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -142,7 +142,7 @@ public class BatchNorm extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "batchnorm_new";
|
return "batchnorm";
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue