diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp index e915df7f0..607980f0d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp @@ -75,7 +75,10 @@ CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) { REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); + else + helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp index 312fff7ec..f62921cc2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp @@ -74,7 +74,10 @@ CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) { REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !"); } - helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); + else + helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp index 63c351c34..dcf827eb1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp @@ -44,7 +44,10 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); - helpers::_depthToSpace(block.launchContext(), input, output, block_size, isNHWC); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC); + else + helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC); STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp index 12b981ac2..9a1683818 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp @@ -36,6 +36,7 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { auto output = OUTPUT_VARIABLE(0); + const uint blockSize = INT_ARG(0); REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); @@ -52,7 +53,10 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); + else + helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp index a782f5b02..0b8c4152d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp @@ -56,7 +56,10 @@ CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) { REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !"); } - helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); + else + helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp index 3daf62ccd..b831dce2f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp @@ -51,7 +51,10 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); - helpers::_spaceTodepth(block.launchContext(), input, output, block_size, isNHWC); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC); + else + helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp index 700e5b8dd..598b3dc30 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp @@ -26,14 +26,14 @@ namespace ops { namespace helpers { template - static void __depthToSpace(NDArray *input, NDArray *output, int block_size, bool isNHWC) { - T *input_ptr = reinterpret_cast(input->buffer()); + static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + T *input_ptr = reinterpret_cast(input.getBuffer()); T *output_ptr = reinterpret_cast(output->buffer()); - const int batch_size = input->sizeAt(0); - const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); @@ -93,13 +93,13 @@ namespace helpers { } } - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input->dataType(); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp index 32968b486..5668ea422 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp @@ -25,14 +25,14 @@ namespace sd { namespace ops { namespace helpers { template - static void _spaceTodepth_(NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto input_ptr = reinterpret_cast(input->buffer()); + static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto input_ptr = reinterpret_cast(input.getBuffer()); auto output_ptr = reinterpret_cast(output->buffer()); - const int batch_size = input->sizeAt(0); - const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_height = isNHWC ? output->sizeAt(1) : output->sizeAt(2); @@ -97,11 +97,11 @@ namespace helpers { } } - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu index 35103d18b..fc3b04ee8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu @@ -88,20 +88,20 @@ namespace helpers { template - static void __depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); + static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); } - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input->dataType(); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); - NDArray::prepareSpecialUse({output}, {input}); + NDArray::prepareSpecialUse({output}, {&input}); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + NDArray::registerSpecialUse({output}, {&input}); } - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu index a5ae42e78..4290a57c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu @@ -90,17 +90,17 @@ namespace helpers { } template - static void _spaceTodepth_(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); + static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); } - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + NDArray::prepareSpecialUse({output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext *context, const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/d_t_s.h b/libnd4j/include/ops/declarable/helpers/d_t_s.h index 20c11ec24..e5ac58e5a 100644 --- a/libnd4j/include/ops/declarable/helpers/d_t_s.h +++ b/libnd4j/include/ops/declarable/helpers/d_t_s.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { namespace helpers { - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/s_t_d.h b/libnd4j/include/ops/declarable/helpers/s_t_d.h index 6dbc64f21..7ef500f03 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_d.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_d.h @@ -24,7 +24,7 @@ namespace sd { namespace ops { namespace helpers { - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 173880e63..6ae27b42a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -151,7 +151,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray ////////////////////////////////////////////////////////////////////////// -static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, +static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x @@ -213,7 +213,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); + mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md); // dLdI dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); @@ -242,7 +242,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdO - mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // mean auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); @@ -316,7 +316,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dfdm / N - auto dfdm = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes); + auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); dfdm *= stdInv; dfdm *= -Ninv; @@ -327,7 +327,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // (2/N)*dfdv NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); + (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); dfdv *= stdInv*stdInv*stdInv; dfdv *= -Ninv; @@ -661,7 +661,10 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) + batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + else + batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); *dLdM = 0; *dLdV = 0; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 04ebaa0d7..12658ede8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -1826,4 +1826,78 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + public void testBatchNormBpNHWC(){ + //Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled + + INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3); + INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); + INDArray epsStrided = eps.permute(1,0,2,3).dup().permute(1,0,2,3); + INDArray mean = Nd4j.rand(DataType.FLOAT, 3); + INDArray var = Nd4j.rand(DataType.FLOAT, 3); + INDArray gamma = Nd4j.rand(DataType.FLOAT, 3); + INDArray beta = Nd4j.rand(DataType.FLOAT, 3); + + assertEquals(eps, epsStrided); + + INDArray out1eps = in.like(); + INDArray out1m = mean.like(); + INDArray out1v = var.like(); + + INDArray out2eps = in.like(); + INDArray out2m = mean.like(); + INDArray out2v = var.like(); + + DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(in, mean, var, gamma, beta, eps) + .addOutputs(out1eps, out1m, out1v) + .addIntegerArguments(1, 1, 3) + .addFloatingPointArguments(1e-5) + .build(); + + DynamicCustomOp op2 = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(in, mean, var, gamma, beta, epsStrided) + .addOutputs(out2eps, out2m, out2v) + .addIntegerArguments(1, 1, 3) + .addFloatingPointArguments(1e-5) + .build(); + + Nd4j.exec(op1); + Nd4j.exec(op2); + + assertEquals(out1eps, out2eps); //Fails here + assertEquals(out1m, out2m); + assertEquals(out1v, out2v); + } + + @Test + public void testSpaceToDepthBadStrides(){ + INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6); + INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3); + assertEquals(in, inBadStrides); + + System.out.println("in: " + in.shapeInfoToString()); + System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString()); + + + INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3); + INDArray out2 = out.like(); + + + CustomOp op1 = DynamicCustomOp.builder("space_to_depth") + .addInputs(in) + .addIntegerArguments(2, 0) //nchw = 0, nhwc = 1 + .addOutputs(out) + .build(); + Nd4j.exec(op1); + + CustomOp op2 = DynamicCustomOp.builder("space_to_depth") + .addInputs(inBadStrides) + .addIntegerArguments(2, 0) //nchw = 0, nhwc = 1 + .addOutputs(out2) + .build(); + Nd4j.exec(op2); + + assertEquals(out, out2); + } }