tensormmul_bp: possible fix of shape mismatch failure and wrong assumption of equal ranks
Signed-off-by: AbdelRauf <rauf@konduit.ai>master
parent
bedd7c6b3a
commit
e0babb58f9
|
@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
|
|||
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
|
||||
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
|
||||
|
||||
// rank always have to be divided by 2
|
||||
|
||||
std::vector<int> axesAdLdC, axesBdLdC;
|
||||
if (dLdCrank > 1) {
|
||||
axesAdLdC.resize(dLdCrank / 2);
|
||||
axesAdLdC.resize(axesA.size());
|
||||
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
|
||||
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue