Fix for certain non-ews cases (#402)
* BtS/StB/StD/DtS dup for views Signed-off-by: raver119 <raver119@gmail.com> * batchnorm_bp dup for views Signed-off-by: raver119 <raver119@gmail.com> * two java tests for bad strides Signed-off-by: raver119 <raver119@gmail.com>master
parent
5ee37a22eb
commit
a10fd4524a
|
@ -75,7 +75,10 @@ CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) {
|
||||||
REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !");
|
REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !");
|
||||||
REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !");
|
REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !");
|
||||||
|
|
||||||
helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
|
||||||
|
else
|
||||||
|
helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,10 @@ CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) {
|
||||||
REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !");
|
REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !");
|
||||||
}
|
}
|
||||||
|
|
||||||
helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output);
|
||||||
|
else
|
||||||
|
helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,10 @@ namespace ops {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
helpers::_depthToSpace(block.launchContext(), input, output, block_size, isNHWC);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC);
|
||||||
|
else
|
||||||
|
helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC);
|
||||||
|
|
||||||
STORE_RESULT(output);
|
STORE_RESULT(output);
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
|
||||||
const uint blockSize = INT_ARG(0);
|
const uint blockSize = INT_ARG(0);
|
||||||
REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize);
|
REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize);
|
||||||
|
|
||||||
|
@ -52,7 +53,10 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !");
|
REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !");
|
||||||
|
|
||||||
helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize);
|
||||||
|
else
|
||||||
|
helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,10 @@ CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) {
|
||||||
REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !");
|
REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !");
|
||||||
}
|
}
|
||||||
|
|
||||||
helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output);
|
||||||
|
else
|
||||||
|
helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,10 @@ namespace ops {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
helpers::_spaceTodepth(block.launchContext(), input, output, block_size, isNHWC);
|
if (shape::strideDescendingCAscendingF(input->shapeInfo()))
|
||||||
|
helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC);
|
||||||
|
else
|
||||||
|
helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,14 +26,14 @@ namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void __depthToSpace(NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
T *input_ptr = reinterpret_cast<T *>(input->buffer());
|
T *input_ptr = reinterpret_cast<T *>(input.getBuffer());
|
||||||
T *output_ptr = reinterpret_cast<T *>(output->buffer());
|
T *output_ptr = reinterpret_cast<T *>(output->buffer());
|
||||||
|
|
||||||
const int batch_size = input->sizeAt(0);
|
const int batch_size = input.sizeAt(0);
|
||||||
const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1);
|
const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1);
|
||||||
const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2);
|
const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2);
|
||||||
const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3);
|
const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3);
|
||||||
|
|
||||||
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
|
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
|
||||||
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
|
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
|
||||||
|
@ -93,13 +93,13 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
auto xType = input->dataType();
|
auto xType = input.dataType();
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,14 +25,14 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _spaceTodepth_(NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
auto input_ptr = reinterpret_cast<T *>(input->buffer());
|
auto input_ptr = reinterpret_cast<T *>(input.getBuffer());
|
||||||
auto output_ptr = reinterpret_cast<T *>(output->buffer());
|
auto output_ptr = reinterpret_cast<T *>(output->buffer());
|
||||||
|
|
||||||
const int batch_size = input->sizeAt(0);
|
const int batch_size = input.sizeAt(0);
|
||||||
const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1);
|
const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1);
|
||||||
const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2);
|
const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2);
|
||||||
const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3);
|
const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3);
|
||||||
|
|
||||||
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
|
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
|
||||||
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
|
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
|
||||||
|
@ -97,11 +97,11 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,20 +88,20 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void __depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
depthToSpaceKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
|
depthToSpaceKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
|
||||||
}
|
}
|
||||||
|
|
||||||
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
auto xType = input->dataType();
|
auto xType = input.dataType();
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {&input});
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {&input});
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,17 +90,17 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _spaceTodepth_(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
spaceToDepthKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
|
spaceToDepthKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
|
||||||
}
|
}
|
||||||
|
|
||||||
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) {
|
void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {&input});
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {&input});
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext *context, const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);
|
void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -24,7 +24,7 @@
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);
|
void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -151,7 +151,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
|
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
|
||||||
NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {
|
NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {
|
||||||
|
|
||||||
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
|
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
|
||||||
|
@ -213,7 +213,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
||||||
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
|
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
|
||||||
|
|
||||||
mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);
|
mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md);
|
||||||
|
|
||||||
// dLdI
|
// dLdI
|
||||||
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
|
||||||
|
@ -242,7 +242,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||||
|
|
||||||
// dLdO
|
// dLdO
|
||||||
mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||||
|
|
||||||
// mean
|
// mean
|
||||||
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
|
||||||
|
@ -316,7 +316,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5
|
stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5
|
||||||
|
|
||||||
// dfdm / N
|
// dfdm / N
|
||||||
auto dfdm = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes);
|
auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes);
|
||||||
dfdm *= stdInv;
|
dfdm *= stdInv;
|
||||||
dfdm *= -Ninv;
|
dfdm *= -Ninv;
|
||||||
|
|
||||||
|
@ -327,7 +327,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
|
||||||
|
|
||||||
// (2/N)*dfdv
|
// (2/N)*dfdv
|
||||||
NDArray dfdv(variance); // empty array with same shape as variance
|
NDArray dfdv(variance); // empty array with same shape as variance
|
||||||
(xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes);
|
(xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes);
|
||||||
dfdv *= stdInv*stdInv*stdInv;
|
dfdv *= stdInv*stdInv*stdInv;
|
||||||
dfdv *= -Ninv;
|
dfdv *= -Ninv;
|
||||||
|
|
||||||
|
@ -661,7 +661,10 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
|
||||||
|
|
||||||
const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
|
const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
|
||||||
|
|
||||||
batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
|
if (shape::strideDescendingCAscendingF(dLdO->shapeInfo()))
|
||||||
|
batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
|
||||||
|
else
|
||||||
|
batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
|
||||||
|
|
||||||
*dLdM = 0;
|
*dLdM = 0;
|
||||||
*dLdV = 0;
|
*dLdV = 0;
|
||||||
|
|
|
@ -1826,4 +1826,78 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBatchNormBpNHWC(){
|
||||||
|
//Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled
|
||||||
|
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3);
|
||||||
|
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
|
||||||
|
INDArray epsStrided = eps.permute(1,0,2,3).dup().permute(1,0,2,3);
|
||||||
|
INDArray mean = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray var = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
|
||||||
|
|
||||||
|
assertEquals(eps, epsStrided);
|
||||||
|
|
||||||
|
INDArray out1eps = in.like();
|
||||||
|
INDArray out1m = mean.like();
|
||||||
|
INDArray out1v = var.like();
|
||||||
|
|
||||||
|
INDArray out2eps = in.like();
|
||||||
|
INDArray out2m = mean.like();
|
||||||
|
INDArray out2v = var.like();
|
||||||
|
|
||||||
|
DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp")
|
||||||
|
.addInputs(in, mean, var, gamma, beta, eps)
|
||||||
|
.addOutputs(out1eps, out1m, out1v)
|
||||||
|
.addIntegerArguments(1, 1, 3)
|
||||||
|
.addFloatingPointArguments(1e-5)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
DynamicCustomOp op2 = DynamicCustomOp.builder("batchnorm_bp")
|
||||||
|
.addInputs(in, mean, var, gamma, beta, epsStrided)
|
||||||
|
.addOutputs(out2eps, out2m, out2v)
|
||||||
|
.addIntegerArguments(1, 1, 3)
|
||||||
|
.addFloatingPointArguments(1e-5)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.exec(op1);
|
||||||
|
Nd4j.exec(op2);
|
||||||
|
|
||||||
|
assertEquals(out1eps, out2eps); //Fails here
|
||||||
|
assertEquals(out1m, out2m);
|
||||||
|
assertEquals(out1v, out2v);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSpaceToDepthBadStrides(){
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6);
|
||||||
|
INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3);
|
||||||
|
assertEquals(in, inBadStrides);
|
||||||
|
|
||||||
|
System.out.println("in: " + in.shapeInfoToString());
|
||||||
|
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
|
||||||
|
|
||||||
|
|
||||||
|
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
|
||||||
|
INDArray out2 = out.like();
|
||||||
|
|
||||||
|
|
||||||
|
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
|
||||||
|
.addInputs(in)
|
||||||
|
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
|
||||||
|
.addOutputs(out)
|
||||||
|
.build();
|
||||||
|
Nd4j.exec(op1);
|
||||||
|
|
||||||
|
CustomOp op2 = DynamicCustomOp.builder("space_to_depth")
|
||||||
|
.addInputs(inBadStrides)
|
||||||
|
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
|
||||||
|
.addOutputs(out2)
|
||||||
|
.build();
|
||||||
|
Nd4j.exec(op2);
|
||||||
|
|
||||||
|
assertEquals(out, out2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue