tensormmul_bp: possible fix of shape mismatch failure and wrong assumption of equal ranks
Signed-off-by: AbdelRauf <rauf@konduit.ai>
Branch: master
parent bedd7c6b3a
commit e0babb58f9
@@ -141,10 +141,10 @@ CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
     std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
     std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
 
-    // rank always have to be divided by 2
     std::vector<int> axesAdLdC, axesBdLdC;
     if (dLdCrank > 1) {
-        axesAdLdC.resize(dLdCrank / 2);
+        axesAdLdC.resize(axesA.size());
         std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
         axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
     }
 
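For context, a minimal sketch of the rank arithmetic behind the fix (not part of the commit; the shapes and contraction counts are illustrative assumptions): for C = tensormmul(A, B, axes), the rank of dL/dC is the number of A's free (non-contracted) axes plus the number of B's free axes. Splitting it as dLdCrank / 2 is therefore only valid when both operands contribute the same number of free axes, while axesA.size() gives the correct split even when A and B have unequal ranks.

// Standalone sketch with assumed shapes: A has rank 3, B has rank 2, one axis contracted on each side.
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t Arank = 3, Brank = 2;              // e.g. A: [2,3,4], B: [4,5]
    const std::size_t contractedA = 1, contractedB = 1;  // contract A's last axis with B's first

    const std::size_t freeA = Arank - contractedA;       // plays the role of axesA.size() == 2
    const std::size_t freeB = Brank - contractedB;       // plays the role of axesB.size() == 1
    const std::size_t dLdCrank = freeA + freeB;          // rank of dL/dC == 3

    std::cout << "dLdCrank / 2 = " << dLdCrank / 2 << "\n";  // 1: the old, equal-rank split
    std::cout << "axesA.size() = " << freeA << "\n";         // 2: the corrected split
    return 0;
}

With the old split, only one leading axis of dL/dC would be assigned to A's two free axes, which is the kind of shape mismatch the commit title refers to.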