tensormmul_bp: possible fix of shape mismatch failure and wrong assumption of equal ranks

Signed-off-by: AbdelRauf <rauf@konduit.ai>
master
AbdelRauf 2021-03-17 15:48:09 +01:00
parent bedd7c6b3a
commit e0babb58f9
1 changed files with 2 additions and 2 deletions

View File

@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// rank always have to be divided by 2
std::vector<int> axesAdLdC, axesBdLdC; std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) { if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2); axesAdLdC.resize(axesA.size());
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
} }