From e0babb58f9e24c56fe3391c89c1e8f54b1f3bedb Mon Sep 17 00:00:00 2001 From: AbdelRauf Date: Wed, 17 Mar 2021 15:48:09 +0100 Subject: [PATCH] tensormmul_bp: possible fix of shape mismatch failure and wrong assumption of equal ranks Signed-off-by: AbdelRauf --- libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp index 159918d3c..090c0942e 100644 --- a/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp +++ b/libnd4j/include/ops/declarable/generic/blas/tensormmul.cpp @@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) { std::vector axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); std::vector axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); - // rank always have to be divided by 2 + std::vector axesAdLdC, axesBdLdC; if (dLdCrank > 1) { - axesAdLdC.resize(dLdCrank / 2); + axesAdLdC.resize(axesA.size()); std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); }