diff --git a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp index 159918d3c..090c0942e 100644 --- a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp @@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) { std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); - // rank always have to be divided by 2 + std::vector axesAdLdC, axesBdLdC; if (dLdCrank > 1) { - axesAdLdC.resize(dLdCrank / 2); + axesAdLdC.resize(axesA.size()); std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); }