/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com), created on 05.06.2018 // #ifndef LIBND4J_MMULHELPER_CPP #define LIBND4J_MMULHELPER_CPP #include "../MmulHelper.h" #include #include #include namespace nd4j { ////////////////////////////////////////////////////////////////////////// nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* A, const nd4j::NDArray* B, const std::initializer_list& axesA, const std::initializer_list& axesB) { std::vector aA(axesA); std::vector aB(axesB); return tensorDot(A, B, aA, aB); } ////////////////////////////////////////////////////////////////////////// nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector& axes_0, const std::vector& axes_1) { std::vector permutAt, permutBt; std::vector shapeAt, shapeBt; auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); NDArray* aPR = a->permute(permutAt); NDArray* bPR = b->permute(permutBt); // check whether reshape is necessary if(!aPR->isSameShape(shapeAt)) aPR->reshapei( shapeAt); if(!bPR->isSameShape(shapeBt)) bPR->reshapei( shapeBt); NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); c->reshapei(outShape); delete aPR; delete bPR; return c; } ////////////////////////////////////////////////////////////////////////// void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector& axes_a, const std::vector& axes_b, const std::vector& permutForC) { std::vector permutAt, permutBt; std::vector shapeAt, shapeBt; ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt); NDArray *cP(c), *cPR(c); // check whether permutation is required if(!permutForC.empty()) cP = c->permute(permutForC); auto aPR = a->permute(permutAt); auto bPR = b->permute(permutBt); // check whether reshape is necessary if(!aPR->isSameShape(shapeAt)) aPR->reshapei(shapeAt); if(!bPR->isSameShape(shapeBt)) bPR->reshapei(shapeBt); if(!cP->isSameShape({aPR->sizeAt(0), bPR->sizeAt(1)})) cPR = cP->reshape(cP->ordering(), {aPR->sizeAt(0), bPR->sizeAt(1)}); mmul(aPR, bPR, cPR, 1.0, 0.0); if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer() cP->assign(cPR); if(cPR != c) delete cPR; if(cP != c) delete cP; delete aPR; delete bPR; } #ifndef __JAVACPP_HACK__ ////////////////////////////////////////////////////////////////////////// void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector>& modifA, const std::vector>& modifB, const std::vector>& modifC) { NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception for(const auto& arr : modifA) whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array for(const auto& arr : modifB) whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; for(const auto& arr : modifC) whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r"; // first step for a array if(!whatToDoWithA.empty()) aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]); // first step for b array if(!whatToDoWithB.empty()) bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]); // rest steps for a array for(int i = 1; i < whatToDoWithA.size(); ++i) if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); // rest steps for b array for(int i = 1; i < whatToDoWithB.size(); ++i) if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); // now work with c array std::vector cArrs = {c}; if(!whatToDoWithC.empty()) { cArrs = std::vector(whatToDoWithC.size()+1, c); for(int i = 0; i < cArrs.size()-1; ++i) cArrs[i+1] = (whatToDoWithC[i] == 'p') ? cArrs[i]->permute(modifC[i]) : cArrs[i]->reshape(c->ordering(), modifC[i]); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c } mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); // check whether new buffer allocation was happened for c array if(!whatToDoWithC.empty()) { for(int i = cArrs.size()-1; i > 0; --i) { if(cArrs[i]->getBuffer() != cArrs[i-1]->getBuffer() || cArrs[i]->getSpecialBuffer() != cArrs[i-1]->getSpecialBuffer()) cArrs[i-1]->assign(cArrs[i]); delete cArrs[i]; } } if(aPR != a) delete aPR; if(bPR != b) delete bPR; } ////////////////////////////////////////////////////////////////////////// NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector>& modifA, const std::vector>& modifB) { NDArray *aPR(const_cast(a)), *bPR(const_cast(b)); std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception for(const auto& arr : modifA) whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array for(const auto& arr : modifB) whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; // first step for a array if(!whatToDoWithA.empty()) aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]); // first step for b array if(!whatToDoWithB.empty()) bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]); // rest steps for a array for(int i = 1; i < whatToDoWithA.size(); ++i) if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); // rest steps for b array for(int i = 1; i < whatToDoWithB.size(); ++i) if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0); if(aPR != a) delete aPR; if(bPR != b) delete bPR; return result; } #endif ////////////////////////////////////////////////////////////////////////// NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta, const char outOrder) { const int aRank = A->rankOf(); const int bRank = B->rankOf(); // input ranks validation if(aRank > bRank && bRank != 2) throw std::runtime_error("MmulHelper::mmulNxN: rank of B array should be equal 2 !"); else if(bRank > aRank && aRank != 2) throw std::runtime_error("MmulHelper::mmulNxN: rank of A array should be equal 2 !"); else if (aRank == bRank ) { for(int i = 0; i < aRank - 2; ++i) if(A->sizeAt(i) != B->sizeAt(i)) throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); } if(A->sizeAt(-1) != B->sizeAt(-2)) throw std::runtime_error("MmulHelper::mmulNxN: shapes of A and B arrays are not suitable for matrix multiplication !"); // validation of C array std::vector cExpectedShape = aRank > bRank ? A->getShapeAsVector() : B->getShapeAsVector(); cExpectedShape[cExpectedShape.size() - 2] = A->sizeAt(-2); cExpectedShape[cExpectedShape.size() - 1] = B->sizeAt(-1); if(C != nullptr ) { if(!C->isSameShape(cExpectedShape)) throw std::runtime_error("MmulHelper::mmulNxN: shape of C array is not suitable for AxB matrix multiplication !"); } else { C = new NDArray(outOrder, cExpectedShape, B->dataType()); } // multiplication const std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(C->rankOf(), {-2, -1}); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(C->getShapeInfo(), dimsToExclude); std::vector idxRanges(2 * C->rankOf()); // #pragma omp parallel for schedule(guided) firstprivate(idxRanges) for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { ShapeUtils::evalIdxRangesForSubArr(i, C->getShapeInfo(), dimsToExclude, idxRanges.data()); NDArray cSubArr = (*C)(idxRanges); if(aRank > bRank) { NDArray aSubArr = (*A)(idxRanges); mmulMxM(&aSubArr, B, &cSubArr, 1., 0., outOrder); } else if(bRank > aRank) { NDArray bSubArr = (*B)(idxRanges); mmulMxM(A, &bSubArr, &cSubArr, 1., 0, outOrder); } else { NDArray aSubArr = (*A)(idxRanges); NDArray bSubArr = (*B)(idxRanges); mmulMxM(&aSubArr, &bSubArr, &cSubArr, 1., 0., outOrder); } } return C; } ////////////////////////////////////////////////////////////////////////// nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, nd4j::NDArray* C , const double alpha, const double beta, const char outOrder) { int lenDim; const int aRank = A->rankOf(); const int bRank = B->rankOf(); const bool isAVector = shape::isCommonVector(A->getShapeInfo(), lenDim); const bool isBVector = shape::isCommonVector(B->getShapeInfo(), lenDim); // dot product of 2 vectors if(isAVector && isBVector && (aRank != 2 || aRank == 2 && (A->isSameShape(B) || bRank == 1 && A->sizeAt(1) == 1))) // (1x1x1 * 1x1) or (1x4 * 1*4) or (4x1 * 4x1) or (4x1 * 4) return dot(A, B, C, alpha, beta); // matrix x matrix if(aRank == 2 && bRank == 2) return mmulMxM(A, B, C, alpha, beta, outOrder); // matrix x vector if(aRank == 2 && isBVector) return mmulMxV(A, B, C, alpha, beta, outOrder); // batched matrix multiplication return mmulNxN(A, B, C, alpha, beta, outOrder); } ////////////////////////////////////////////////////////////////////////// void MmulHelper::matmul(const nd4j::NDArray* x, const nd4j::NDArray* y, nd4j::NDArray* z, const bool transX, const bool transY) { int xRank = x->rankOf(); int yRank = y->rankOf(); auto outShape = ShapeUtils::evalShapeForMatmul(x->getShapeInfo(), y->getShapeInfo(), transX, transY); if(!z->isSameShape(outShape)) { nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); throw std::invalid_argument(""); } NDArray* xT(const_cast(x)), *yT(const_cast(y)), *zT(z); if((transX && xRank > 1) || (transY && yRank > 1)) { const int rank = xRank >= yRank ? xRank : yRank; std::vector permut(rank); for (int i = 0; i < rank-2; ++i) permut[i] = i; permut[rank-2] = rank - 1; permut[rank-1] = rank - 2; if(transX) xT = x->permute(permut); if(transY) yT = y->permute(permut); } if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case xT = x->reshape(x->ordering(), {1, x->lengthOf()}); // please note x is not transposed in this case (since xRank=1) zT = z->reshape(z->ordering(), {1, z->lengthOf()}); } mmul(xT, yT, zT, 1., 0.); } else { // rest cases - batched mmul const int batchRank = xRank - 2; std::vector dimsToExclude(batchRank); for(int i = 0; i < batchRank; ++i) dimsToExclude[i] = i; const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(xT->getShapeInfo(), dimsToExclude); //PRAGMA_OMP_PARALLEL_FOR for(Nd4jLong i = 0; i < numOfSubArrs; ++i) { auto xSubArr = (*xT)(i, dimsToExclude); auto ySubArr = (*yT)(i, dimsToExclude); auto zSubArr = (*zT)(i, dimsToExclude); mmul(&xSubArr, &ySubArr, &zSubArr, 1., 0.); } } if(xT != x) delete xT; if(yT != y) delete yT; if(zT != z) delete zT; } //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const char cOrder, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* A, const int lda, const void* B, const int ldb, const double beta, void* C, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const char aOrder, const int M, const int N, const double alpha, const void* A, const int lda, const void* B, const int incx, const double beta, void* C, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualDot, (const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); } #endif