Fix for certain non-ews cases (#402)

* BtS/StB/StD/DtS dup for views

Signed-off-by: raver119 <raver119@gmail.com>

* batchnorm_bp dup for views

Signed-off-by: raver119 <raver119@gmail.com>

* two java tests for bad strides

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-21 12:41:30 +03:00 committed by GitHub
parent 5ee37a22eb
commit a10fd4524a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 142 additions and 46 deletions

View File

@ -75,7 +75,10 @@ CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) {
REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !");
REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !");
helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
else
helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
return Status::OK(); return Status::OK();
} }

View File

@ -74,7 +74,10 @@ CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) {
REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !"); REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !");
} }
helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output);
else
helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output);
return Status::OK(); return Status::OK();
} }

View File

@ -44,7 +44,10 @@ namespace ops {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
helpers::_depthToSpace(block.launchContext(), input, output, block_size, isNHWC); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC);
else
helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC);
STORE_RESULT(output); STORE_RESULT(output);

View File

@ -36,6 +36,7 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const uint blockSize = INT_ARG(0); const uint blockSize = INT_ARG(0);
REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize);
@ -52,7 +53,10 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) {
REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !");
helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize);
else
helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize);
return Status::OK(); return Status::OK();
} }

View File

@ -56,7 +56,10 @@ CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) {
REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !"); REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !");
} }
helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output);
else
helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output);
return Status::OK(); return Status::OK();
} }

View File

@ -51,7 +51,10 @@ namespace ops {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
helpers::_spaceTodepth(block.launchContext(), input, output, block_size, isNHWC); if (shape::strideDescendingCAscendingF(input->shapeInfo()))
helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC);
else
helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC);
return Status::OK(); return Status::OK();
} }

View File

@ -26,14 +26,14 @@ namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static void __depthToSpace(NDArray *input, NDArray *output, int block_size, bool isNHWC) { static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
T *input_ptr = reinterpret_cast<T *>(input->buffer()); T *input_ptr = reinterpret_cast<T *>(input.getBuffer());
T *output_ptr = reinterpret_cast<T *>(output->buffer()); T *output_ptr = reinterpret_cast<T *>(output->buffer());
const int batch_size = input->sizeAt(0); const int batch_size = input.sizeAt(0);
const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1);
const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2);
const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3);
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
@ -93,13 +93,13 @@ namespace helpers {
} }
} }
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
auto xType = input->dataType(); auto xType = input.dataType();
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
} }
} }

View File

@ -25,14 +25,14 @@ namespace sd {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static void _spaceTodepth_(NDArray *input, NDArray *output, int block_size, bool isNHWC) { static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
auto input_ptr = reinterpret_cast<T *>(input->buffer()); auto input_ptr = reinterpret_cast<T *>(input.getBuffer());
auto output_ptr = reinterpret_cast<T *>(output->buffer()); auto output_ptr = reinterpret_cast<T *>(output->buffer());
const int batch_size = input->sizeAt(0); const int batch_size = input.sizeAt(0);
const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1);
const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2);
const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3);
const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1);
const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2);
@ -97,11 +97,11 @@ namespace helpers {
} }
} }
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
} }
} }

View File

@ -88,20 +88,20 @@ namespace helpers {
template <typename T> template <typename T>
static void __depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
depthToSpaceKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); depthToSpaceKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
} }
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
auto xType = input->dataType(); auto xType = input.dataType();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {&input});
BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {&input});
} }
BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES);
} }
} }

View File

@ -90,17 +90,17 @@ namespace helpers {
} }
template <typename T> template <typename T>
static void _spaceTodepth_(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
spaceToDepthKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); spaceToDepthKernel<T><<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC);
} }
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) {
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {&input});
BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {&input});
} }
BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext *context, const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES);
} }
} }

View File

@ -25,7 +25,7 @@ namespace sd {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);
} }
} }
} }

View File

@ -24,7 +24,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);
} }
} }
} }

View File

@ -151,7 +151,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
@ -213,7 +213,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md);
// dLdI // dLdI
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
@ -242,7 +242,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// dLdO // dLdO
mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// mean // mean
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
@ -316,7 +316,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5
// dfdm / N // dfdm / N
auto dfdm = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes); auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes);
dfdm *= stdInv; dfdm *= stdInv;
dfdm *= -Ninv; dfdm *= -Ninv;
@ -327,7 +327,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// (2/N)*dfdv // (2/N)*dfdv
NDArray dfdv(variance); // empty array with same shape as variance NDArray dfdv(variance); // empty array with same shape as variance
(xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes);
dfdv *= stdInv*stdInv*stdInv; dfdv *= stdInv*stdInv*stdInv;
dfdv *= -Ninv; dfdv *= -Ninv;
@ -661,7 +661,10 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW); if (shape::strideDescendingCAscendingF(dLdO->shapeInfo()))
batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
else
batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
*dLdM = 0; *dLdM = 0;
*dLdV = 0; *dLdV = 0;

View File

@ -1826,4 +1826,78 @@ public class CustomOpsTests extends BaseNd4jTest {
} }
@Test
public void testBatchNormBpNHWC(){
//Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3);
INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape());
INDArray epsStrided = eps.permute(1,0,2,3).dup().permute(1,0,2,3);
INDArray mean = Nd4j.rand(DataType.FLOAT, 3);
INDArray var = Nd4j.rand(DataType.FLOAT, 3);
INDArray gamma = Nd4j.rand(DataType.FLOAT, 3);
INDArray beta = Nd4j.rand(DataType.FLOAT, 3);
assertEquals(eps, epsStrided);
INDArray out1eps = in.like();
INDArray out1m = mean.like();
INDArray out1v = var.like();
INDArray out2eps = in.like();
INDArray out2m = mean.like();
INDArray out2v = var.like();
DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, gamma, beta, eps)
.addOutputs(out1eps, out1m, out1v)
.addIntegerArguments(1, 1, 3)
.addFloatingPointArguments(1e-5)
.build();
DynamicCustomOp op2 = DynamicCustomOp.builder("batchnorm_bp")
.addInputs(in, mean, var, gamma, beta, epsStrided)
.addOutputs(out2eps, out2m, out2v)
.addIntegerArguments(1, 1, 3)
.addFloatingPointArguments(1e-5)
.build();
Nd4j.exec(op1);
Nd4j.exec(op2);
assertEquals(out1eps, out2eps); //Fails here
assertEquals(out1m, out2m);
assertEquals(out1v, out2v);
}
@Test
public void testSpaceToDepthBadStrides(){
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6);
INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3);
assertEquals(in, inBadStrides);
System.out.println("in: " + in.shapeInfoToString());
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
INDArray out2 = out.like();
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
.addInputs(in)
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
.addOutputs(out)
.build();
Nd4j.exec(op1);
CustomOp op2 = DynamicCustomOp.builder("space_to_depth")
.addInputs(inBadStrides)
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
.addOutputs(out2)
.build();
Nd4j.exec(op2);
assertEquals(out, out2);
}
} }